supervisely-6.73.452-py3-none-any.whl → supervisely-6.73.513-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/__init__.py +25 -1
- supervisely/annotation/annotation.py +8 -2
- supervisely/annotation/json_geometries_map.py +13 -12
- supervisely/api/annotation_api.py +6 -3
- supervisely/api/api.py +2 -0
- supervisely/api/app_api.py +10 -1
- supervisely/api/dataset_api.py +74 -12
- supervisely/api/entities_collection_api.py +10 -0
- supervisely/api/entity_annotation/figure_api.py +28 -0
- supervisely/api/entity_annotation/object_api.py +3 -3
- supervisely/api/entity_annotation/tag_api.py +63 -12
- supervisely/api/guides_api.py +210 -0
- supervisely/api/image_api.py +4 -0
- supervisely/api/labeling_job_api.py +83 -1
- supervisely/api/labeling_queue_api.py +33 -7
- supervisely/api/module_api.py +5 -0
- supervisely/api/project_api.py +71 -26
- supervisely/api/storage_api.py +3 -1
- supervisely/api/task_api.py +13 -2
- supervisely/api/team_api.py +4 -3
- supervisely/api/video/video_annotation_api.py +119 -3
- supervisely/api/video/video_api.py +65 -14
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +23 -7
- supervisely/app/development/development.py +18 -2
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +105 -0
- supervisely/app/fastapi/subapp.py +88 -42
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +6 -0
- supervisely/app/widgets/activity_feed/__init__.py +0 -0
- supervisely/app/widgets/activity_feed/activity_feed.py +239 -0
- supervisely/app/widgets/activity_feed/style.css +78 -0
- supervisely/app/widgets/activity_feed/template.html +22 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/classes_list_selector/classes_list_selector.py +121 -9
- supervisely/app/widgets/classes_list_selector/template.html +60 -93
- supervisely/app/widgets/classes_mapping/classes_mapping.py +13 -12
- supervisely/app/widgets/classes_table/classes_table.py +1 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +1 -1
- supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
- supervisely/app/widgets/fast_table/fast_table.py +184 -60
- supervisely/app/widgets/fast_table/template.html +1 -1
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +564 -0
- supervisely/app/widgets/heatmap/script.js +533 -0
- supervisely/app/widgets/heatmap/style.css +233 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/modal/__init__.py +0 -0
- supervisely/app/widgets/modal/modal.py +198 -0
- supervisely/app/widgets/modal/template.html +10 -0
- supervisely/app/widgets/object_class_view/object_class_view.py +3 -0
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -3
- supervisely/app/widgets/select_class/__init__.py +0 -0
- supervisely/app/widgets/select_class/select_class.py +363 -0
- supervisely/app/widgets/select_class/template.html +50 -0
- supervisely/app/widgets/select_cuda/select_cuda.py +22 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
- supervisely/app/widgets/select_tag/__init__.py +0 -0
- supervisely/app/widgets/select_tag/select_tag.py +352 -0
- supervisely/app/widgets/select_tag/template.html +64 -0
- supervisely/app/widgets/select_team/select_team.py +37 -4
- supervisely/app/widgets/select_team/template.html +4 -5
- supervisely/app/widgets/select_user/__init__.py +0 -0
- supervisely/app/widgets/select_user/select_user.py +270 -0
- supervisely/app/widgets/select_user/template.html +13 -0
- supervisely/app/widgets/select_workspace/select_workspace.py +59 -10
- supervisely/app/widgets/select_workspace/template.html +9 -12
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/aug/aug.py +6 -2
- supervisely/convert/base_converter.py +1 -0
- supervisely/convert/converter.py +2 -2
- supervisely/convert/image/image_converter.py +3 -1
- supervisely/convert/image/image_helper.py +48 -4
- supervisely/convert/image/label_studio/label_studio_converter.py +2 -0
- supervisely/convert/image/medical2d/medical2d_helper.py +2 -24
- supervisely/convert/image/multispectral/multispectral_converter.py +6 -0
- supervisely/convert/image/pascal_voc/pascal_voc_converter.py +8 -5
- supervisely/convert/image/pascal_voc/pascal_voc_helper.py +7 -0
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_converter.py +33 -3
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_helper.py +12 -5
- supervisely/convert/pointcloud/las/las_converter.py +13 -1
- supervisely/convert/pointcloud/las/las_helper.py +110 -11
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +27 -16
- supervisely/convert/pointcloud/pointcloud_converter.py +91 -3
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +58 -22
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +21 -47
- supervisely/convert/video/__init__.py +1 -0
- supervisely/convert/video/multi_view/__init__.py +0 -0
- supervisely/convert/video/multi_view/multi_view.py +543 -0
- supervisely/convert/video/sly/sly_video_converter.py +359 -3
- supervisely/convert/video/video_converter.py +22 -2
- supervisely/convert/volume/dicom/dicom_converter.py +13 -5
- supervisely/convert/volume/dicom/dicom_helper.py +30 -18
- supervisely/geometry/constants.py +1 -0
- supervisely/geometry/geometry.py +4 -0
- supervisely/geometry/helpers.py +5 -1
- supervisely/geometry/oriented_bbox.py +676 -0
- supervisely/geometry/rectangle.py +2 -1
- supervisely/io/env.py +76 -1
- supervisely/io/fs.py +21 -0
- supervisely/nn/benchmark/base_evaluator.py +104 -11
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -8
- supervisely/nn/benchmark/object_detection/evaluator.py +20 -4
- supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +10 -5
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +34 -16
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -1
- supervisely/nn/benchmark/visualization/evaluation_result.py +66 -4
- supervisely/nn/inference/cache.py +43 -18
- supervisely/nn/inference/gui/serving_gui_template.py +5 -2
- supervisely/nn/inference/inference.py +795 -199
- supervisely/nn/inference/inference_request.py +42 -9
- supervisely/nn/inference/predict_app/gui/classes_selector.py +83 -12
- supervisely/nn/inference/predict_app/gui/gui.py +676 -488
- supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
- supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
- supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
- supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
- supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
- supervisely/nn/inference/predict_app/gui/utils.py +236 -119
- supervisely/nn/inference/predict_app/predict_app.py +2 -2
- supervisely/nn/inference/session.py +43 -35
- supervisely/nn/inference/tracking/bbox_tracking.py +113 -34
- supervisely/nn/inference/tracking/tracker_interface.py +7 -2
- supervisely/nn/inference/uploader.py +139 -12
- supervisely/nn/live_training/__init__.py +7 -0
- supervisely/nn/live_training/api_server.py +111 -0
- supervisely/nn/live_training/artifacts_utils.py +243 -0
- supervisely/nn/live_training/checkpoint_utils.py +229 -0
- supervisely/nn/live_training/dynamic_sampler.py +44 -0
- supervisely/nn/live_training/helpers.py +14 -0
- supervisely/nn/live_training/incremental_dataset.py +146 -0
- supervisely/nn/live_training/live_training.py +497 -0
- supervisely/nn/live_training/loss_plateau_detector.py +111 -0
- supervisely/nn/live_training/request_queue.py +52 -0
- supervisely/nn/model/model_api.py +9 -0
- supervisely/nn/prediction_dto.py +12 -1
- supervisely/nn/tracker/base_tracker.py +11 -1
- supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
- supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py +7 -4
- supervisely/nn/tracker/botsort_tracker.py +94 -65
- supervisely/nn/tracker/visualize.py +87 -90
- supervisely/nn/training/gui/classes_selector.py +16 -1
- supervisely/nn/training/train_app.py +28 -29
- supervisely/project/data_version.py +115 -51
- supervisely/project/download.py +1 -1
- supervisely/project/pointcloud_episode_project.py +37 -8
- supervisely/project/pointcloud_project.py +30 -2
- supervisely/project/project.py +14 -2
- supervisely/project/project_meta.py +27 -1
- supervisely/project/project_settings.py +32 -18
- supervisely/project/versioning/__init__.py +1 -0
- supervisely/project/versioning/common.py +20 -0
- supervisely/project/versioning/schema_fields.py +35 -0
- supervisely/project/versioning/video_schema.py +221 -0
- supervisely/project/versioning/volume_schema.py +87 -0
- supervisely/project/video_project.py +717 -15
- supervisely/project/volume_project.py +623 -5
- supervisely/template/experiment/experiment.html.jinja +4 -4
- supervisely/template/experiment/experiment_generator.py +14 -21
- supervisely/template/live_training/__init__.py +0 -0
- supervisely/template/live_training/header.html.jinja +96 -0
- supervisely/template/live_training/live_training.html.jinja +51 -0
- supervisely/template/live_training/live_training_generator.py +464 -0
- supervisely/template/live_training/sly-style.css +402 -0
- supervisely/template/live_training/template.html.jinja +18 -0
- supervisely/versions.json +28 -26
- supervisely/video/sampling.py +39 -20
- supervisely/video/video.py +40 -11
- supervisely/video_annotation/video_object.py +29 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info}/METADATA +56 -39
- {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info}/RECORD +189 -142
- {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info}/WHEEL +1 -1
- {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info/licenses}/LICENSE +0 -0
- {supervisely-6.73.452.dist-info → supervisely-6.73.513.dist-info}/top_level.txt +0 -0
supervisely/nn/inference/inference.py (+795 -199):

@@ -5,12 +5,14 @@ import asyncio
 import inspect
 import json
 import os
+import queue
 import re
 import shutil
 import subprocess
 import tempfile
 import threading
 import time
+import uuid
 from collections import OrderedDict, defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict, dataclass
@@ -52,6 +54,7 @@ from supervisely.annotation.tag_meta import TagMeta, TagValueType
 from supervisely.api.api import Api, ApiField
 from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
 from supervisely.api.image_api import ImageInfo
+from supervisely.api.video.video_api import VideoInfo
 from supervisely.app.content import get_data_dir
 from supervisely.app.fastapi.subapp import (
     Application,
@@ -67,6 +70,7 @@ from supervisely.decorators.inference import (
     process_images_batch_sliding_window,
 )
 from supervisely.geometry.any_geometry import AnyGeometry
+from supervisely.geometry.geometry import Geometry
 from supervisely.imaging.color import get_predefined_colors
 from supervisely.io.fs import list_files
 from supervisely.nn.experiments import ExperimentInfo
@@ -75,7 +79,7 @@ from supervisely.nn.inference.inference_request import (
     InferenceRequest,
     InferenceRequestsManager,
 )
-from supervisely.nn.inference.uploader import Uploader
+from supervisely.nn.inference.uploader import Downloader, Uploader
 from supervisely.nn.model.model_api import ModelAPI, Prediction
 from supervisely.nn.prediction_dto import Prediction as PredictionDTO
 from supervisely.nn.utils import (
@@ -94,6 +98,17 @@ from supervisely.project.project_meta import ProjectMeta
 from supervisely.sly_logger import logger
 from supervisely.task.progress import Progress
 from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
+from supervisely.video_annotation.frame import Frame
+from supervisely.video_annotation.frame_collection import FrameCollection
+from supervisely.video_annotation.key_id_map import KeyIdMap
+from supervisely.video_annotation.video_annotation import VideoAnnotation
+from supervisely.video_annotation.video_figure import VideoFigure
+from supervisely.video_annotation.video_object import VideoObject
+from supervisely.video_annotation.video_object_collection import (
+    VideoObject,
+    VideoObjectCollection,
+)
+from supervisely.video_annotation.video_tag_collection import VideoTagCollection
 
 try:
     from typing import Literal
@@ -140,6 +155,7 @@ class Inference:
     """Default batch size for inference"""
     INFERENCE_SETTINGS: str = None
     """Path to file with custom inference settings"""
+    DEFAULT_IOU_MERGE_THRESHOLD: float = 0.9
 
     def __init__(
         self,
@@ -193,7 +209,6 @@ class Inference:
         self._task_id = None
         self._sliding_window_mode = sliding_window_mode
         self._autostart_delay_time = 5 * 60  # 5 min
-        self._tracker = None
         self._hardware: str = None
         if custom_inference_settings is None:
             if self.INFERENCE_SETTINGS is not None:
@@ -427,7 +442,7 @@ class Inference:
 
             device = "cuda" if torch.cuda.is_available() else "cpu"
         except Exception as e:
-            logger.
+            logger.warning(
                 f"Device auto detection failed, set to default 'cpu', reason: {repr(e)}"
             )
             device = "cpu"
@@ -734,15 +749,15 @@ class Inference:
         for model in self.pretrained_models:
             model_meta = model.get("meta")
             if model_meta is not None:
-
-                if
-                if
+                this_model_name = model_meta.get("model_name")
+                if this_model_name is not None:
+                    if this_model_name.lower() == model_name.lower():
                         selected_model = model
                         break
             else:
-
-                if
-                if
+                this_model_name = model.get("model_name")
+                if this_model_name is not None:
+                    if this_model_name.lower() == model_name.lower():
                         selected_model = model
                         break
 
@@ -1359,6 +1374,7 @@ class Inference:
 
         if tracker == "botsort":
             from supervisely.nn.tracker import BotSortTracker
+
             device = tracker_settings.get("device", self.device)
             logger.debug(f"Initializing BotSort tracker with device: {device}")
             return BotSortTracker(settings=tracker_settings, device=device)
@@ -1375,15 +1391,15 @@ class Inference:
             if classes is not None:
                 num_classes = len(classes)
         except NotImplementedError:
-            logger.
+            logger.warning(f"get_classes() function not implemented for {type(self)} object.")
        except AttributeError:
-            logger.
+            logger.warning("Probably, get_classes() function not working without model deploy.")
        except Exception as exc:
-            logger.
+            logger.warning("Unknown exception. Please, contact support")
            logger.exception(exc)
 
        if num_classes is None:
-            logger.
+            logger.warning(f"get_classes() function return {classes}; skip classes processing.")
 
        return {
            "app_name": get_name_from_env(default="Neural Network Serving"),
@@ -1401,6 +1417,42 @@ class Inference:
 
     # pylint: enable=method-hidden
 
+    def get_tracking_settings(self) -> Dict[str, Dict[str, Any]]:
+        """
+        Get default parameters for all available tracking algorithms.
+
+        Returns:
+            {"botsort": {"track_high_thresh": 0.6, ...}}
+            Empty dict if tracking not supported.
+        """
+        info = self.get_info()
+        trackers_params = {}
+
+        tracking_support = info.get("tracking_on_videos_support")
+        if not tracking_support:
+            return trackers_params
+
+        tracking_algorithms = info.get("tracking_algorithms", [])
+
+        for tracker_name in tracking_algorithms:
+            try:
+                if tracker_name == "botsort":
+                    from supervisely.nn.tracker import BotSortTracker
+
+                    trackers_params[tracker_name] = BotSortTracker.get_default_params()
+                # Add other trackers here as elif blocks
+                else:
+                    logger.debug(f"Tracker '{tracker_name}' not implemented")
+            except Exception as e:
+                logger.warning(f"Failed to get params for '{tracker_name}': {e}")
+
+        INTERNAL_FIELDS = {"device", "fps"}
+        for tracker_name, params in trackers_params.items():
+            trackers_params[tracker_name] = {
+                k: v for k, v in params.items() if k not in INTERNAL_FIELDS
+            }
+        return trackers_params
+
     def get_human_readable_info(self, replace_none_with: Optional[str] = None):
         hr_info = {}
         info = self.get_info()
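A quick sketch of how the new method is consumed. The `model` handle below is illustrative (any deployed `Inference` subclass whose `get_info()` reports `tracking_on_videos_support`); the `botsort` key and `track_high_thresh` field come from the docstring above, the rest is assumed:

```python
# `model` is a hypothetical deployed Inference subclass instance.
params = model.get_tracking_settings()
# e.g. {"botsort": {"track_high_thresh": 0.6, ...}}; the internal
# "device" and "fps" fields are stripped before the dict is returned.
for tracker_name, defaults in params.items():
    print(tracker_name, defaults)
```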
@@ -1952,7 +2004,7 @@ class Inference:
         else:
             n_frames = frames_reader.frames_count()
 
-
+        inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
 
         progress_total = (n_frames + step - 1) // step
         inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
@@ -1978,8 +2030,8 @@ class Inference:
                 settings=inference_settings,
             )
 
-            if
-                anns = self._apply_tracker_to_anns(frames, anns)
+            if inference_request.tracker is not None:
+                anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
 
             predictions = [
                 Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
@@ -1994,10 +2046,9 @@ class Inference:
             inference_request.done(len(batch_results))
             logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
         video_ann_json = None
-        if
+        if inference_request.tracker is not None:
             inference_request.set_stage("Postprocess...", 0, 1)
-
-            video_ann_json = self._tracker.video_annotation.to_json()
+            video_ann_json = inference_request.tracker.video_annotation.to_json()
         inference_request.done()
         result = {"ann": results, "video_ann": video_ann_json}
         inference_request.final_result = result.copy()
@@ -2029,7 +2080,7 @@ class Inference:
         upload_mode = state.get("upload_mode", None)
         iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
         if upload_mode == "iou_merge" and iou_merge_threshold is None:
-            iou_merge_threshold = 0.
+            iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD  # TODO: change to 0.9
 
         images_infos = api.image.get_info_by_id_batch(image_ids)
         images_infos_dict = {im_info.id: im_info for im_info in images_infos}
@@ -2071,14 +2122,9 @@ class Inference:
                     output_dataset_id
                 ] = output_dataset_info
 
-
-
-
-            dataset_image_infos[image_info.dataset_id].append(image_info)
-        for dataset_id, ds_image_infos in dataset_image_infos.items():
-            self.cache.run_cache_task_manually(
-                api, [info.id for info in ds_image_infos], dataset_id=dataset_id
-            )
+        def download_f(item: int):
+            self.cache.download_image(api, item)
+            return item
 
         _upload_predictions = partial(
             self.upload_predictions,
@@ -2094,7 +2140,9 @@ class Inference:
         )
 
         _add_results_to_request = partial(
-            self.add_results_to_request,
+            self.add_results_to_request,
+            inference_request=inference_request,
+            progress_cb=inference_request.done,
         )
 
         if upload_mode is None:
@@ -2103,40 +2151,60 @@ class Inference:
             upload_f = _upload_predictions
 
         inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, len(image_ids))
+        download_workers = max(8, min(batch_size, 64))
         with Uploader(upload_f, logger=logger) as uploader:
-
-
-
-
-
-
-
-
-            )
-
-
-
-
-
-
-
+            with Downloader(download_f, max_workers=download_workers, logger=logger) as downloader:
+                for image_id in image_ids:
+                    downloader.put(image_id)
+                downloader.next(100)
+                for image_ids_batch in batched(image_ids, batch_size=batch_size):
+                    if uploader.has_exception():
+                        exception = uploader.exception
+                        raise exception
+                    if inference_request.is_stopped():
+                        logger.debug(
+                            f"Cancelling inference...",
+                            extra={"inference_request_uuid": inference_request.uuid},
+                        )
+                        break
+                    if inference_request.is_paused():
+                        logger.info("Inference request is paused. Waiting...")
+                        while inference_request.is_paused():
+                            if (
+                                inference_request.paused_for()
+                                > inference_request.PAUSE_SLEEP_MAX_WAIT
+                            ):
+                                logger.info(
+                                    "Inference request has been paused for too long. Cancelling..."
+                                )
+                                raise RuntimeError("Inference request cancelled due to long pause.")
+                            time.sleep(inference_request.PAUSE_SLEEP_INTERVAL)
 
-
-
-
-
-
-
-                    name=image_info.name,
-                    image_id=image_info.id,
-                    dataset_id=image_info.dataset_id,
-                    project_id=dataset_info.project_id,
+                    images_nps = [
+                        self.cache.download_image(api, img_id) for img_id in image_ids_batch
+                    ]
+                    downloader.next(len(image_ids_batch))
+                    anns, slides_data = self._inference_auto(
+                        source=images_nps,
+                        settings=inference_settings,
                     )
-                prediction.extra_data["slides_data"] = this_slides_data
-                batch_predictions.append(prediction)
 
-
+                    batch_predictions = []
+                    for image_id, ann, this_slides_data in zip(image_ids_batch, anns, slides_data):
+                        image_info: ImageInfo = images_infos_dict[image_id]
+                        dataset_info = dataset_infos_dict[image_info.dataset_id]
+                        prediction = Prediction(
+                            ann,
+                            model_meta=self.model_meta,
+                            name=image_info.name,
+                            image_id=image_info.id,
+                            dataset_id=image_info.dataset_id,
+                            project_id=dataset_info.project_id,
+                        )
+                        prediction.extra_data["slides_data"] = this_slides_data
+                        batch_predictions.append(prediction)
+
+                    uploader.put(batch_predictions)
 
     def _inference_video_id(
         self,
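The new `Downloader` (added in supervisely/nn/inference/uploader.py, +139 -12) drives a bounded prefetch: image IDs are queued up front with `put()`, and `next(n)` advances the in-flight window as batches are consumed. A minimal standard-library sketch of the same idea, illustrative rather than the actual supervisely implementation:

```python
import queue
from concurrent.futures import ThreadPoolExecutor

def prefetched(items, fetch, max_workers=8, window=100):
    """Yield fetch(item) in submission order, keeping at most `window` fetches in flight."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        in_flight = queue.Queue()
        it = iter(items)
        for _ in range(window):  # prime the window
            try:
                in_flight.put(pool.submit(fetch, next(it)))
            except StopIteration:
                break
        while not in_flight.empty():
            yield in_flight.get().result()   # consume the oldest fetch...
            try:
                in_flight.put(pool.submit(fetch, next(it)))  # ...and refill
            except StopIteration:
                pass
```

The `max(8, min(batch_size, 64))` worker count in the diff keeps at least 8 download threads while capping them at 64.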
@@ -2181,7 +2249,7 @@ class Inference:
         else:
             n_frames = video_info.frames_count
 
-
+        inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
 
         logger.debug(
             f"Video info:",
@@ -2218,8 +2286,8 @@ class Inference:
                 settings=inference_settings,
             )
 
-            if
-                anns = self._apply_tracker_to_anns(frames, anns)
+            if inference_request.tracker is not None:
+                anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
 
             predictions = [
                 Prediction(
@@ -2228,8 +2296,8 @@ class Inference:
                     frame_index=frame_index,
                     video_id=video_info.id,
                     dataset_id=video_info.dataset_id,
-
-
+                    project_id=video_info.project_id,
+                )
                 for ann, frame_index in zip(anns, batch)
             ]
             for pred, this_slides_data in zip(predictions, slides_data):
@@ -2240,9 +2308,169 @@ class Inference:
             inference_request.done(len(batch_results))
             logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
         video_ann_json = None
-        if
+        if inference_request.tracker is not None:
+            inference_request.set_stage("Postprocess...", 0, progress_total)
+
+            video_ann_json = inference_request.tracker.create_video_annotation(
+                video_info.frames_count,
+                start_frame_index,
+                step=step,
+                progress_cb=inference_request.done,
+            ).to_json()
+        inference_request.final_result = {"video_ann": video_ann_json}
+        return video_ann_json
+
+    def _tracking_by_detection(self, api: Api, state: dict, inference_request: InferenceRequest):
+        logger.debug("Inferring video_id...", extra={"state": state})
+        inference_settings = self._get_inference_settings(state)
+        logger.debug(f"Inference settings:", extra=inference_settings)
+        batch_size = self._get_batch_size_from_state(state)
+        video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
+        if video_id is None:
+            raise ValueError("Video id is not provided")
+        video_info = api.video.get_info_by_id(video_id)
+        start_frame_index = get_value_for_keys(
+            state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
+        )
+        if start_frame_index is None:
+            start_frame_index = 0
+        step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
+        if step is None:
+            step = 1
+        end_frame_index = get_value_for_keys(
+            state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
+        )
+        duration = state.get("duration", None)
+        frames_count = get_value_for_keys(
+            state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
+        )
+        tracking = state.get("tracker", None)
+        direction = state.get("direction", "forward")
+        direction = 1 if direction == "forward" else -1
+        track_id = get_value_for_keys(state, ["trackId", "track_id"], ignore_none=True)
+
+        if frames_count is not None:
+            n_frames = frames_count
+        elif end_frame_index is not None:
+            n_frames = end_frame_index - start_frame_index
+        elif duration is not None:
+            fps = video_info.frames_count / video_info.duration
+            n_frames = int(duration * fps)
+        else:
+            n_frames = video_info.frames_count
+
+        inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
+
+        logger.debug(
+            f"Video info:",
+            extra=dict(
+                w=video_info.frame_width,
+                h=video_info.frame_height,
+                start_frame_index=start_frame_index,
+                n_frames=n_frames,
+            ),
+        )
+
+        # start downloading video in background
+        self.cache.run_cache_task_manually(api, None, video_id=video_id)
+
+        progress_total = (n_frames + step - 1) // step
+        inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
+
+        _upload_f = partial(
+            self.upload_predictions_to_video,
+            api=api,
+            video_info=video_info,
+            track_id=track_id,
+            context=inference_request.context,
+            progress_cb=inference_request.done,
+            inference_request=inference_request,
+        )
+
+        _range = (start_frame_index, start_frame_index + direction * n_frames)
+        if _range[0] > _range[1]:
+            _range = (_range[1], _range[0])
+
+        def _notify_f(predictions: List[Prediction]):
+            logger.debug(
+                "Notifying tracking progress...",
+                extra={
+                    "track_id": track_id,
+                    "range": _range,
+                    "current": inference_request.progress.current,
+                    "total": inference_request.progress.total,
+                },
+            )
+            stopped = self.api.video.notify_progress(
+                track_id=track_id,
+                video_id=video_info.id,
+                frame_start=_range[0],
+                frame_end=_range[1],
+                current=inference_request.progress.current,
+                total=inference_request.progress.total,
+            )
+            if stopped:
+                inference_request.stop()
+                logger.info("Tracking has been stopped by user", extra={"track_id": track_id})
+
+        def _exception_handler(e: Exception):
+            self.api.video.notify_tracking_error(
+                track_id=track_id,
+                error=str(type(e)),
+                message=str(e),
+            )
+            raise e
+
+        with Uploader(
+            upload_f=_upload_f,
+            notify_f=_notify_f,
+            exception_handler=_exception_handler,
+            logger=logger,
+        ) as uploader:
+            for batch in batched(
+                range(
+                    start_frame_index, start_frame_index + direction * n_frames, direction * step
+                ),
+                batch_size,
+            ):
+                if inference_request.is_stopped():
+                    logger.debug(
+                        f"Cancelling inference video...",
+                        extra={"inference_request_uuid": inference_request.uuid},
+                    )
+                    break
+                logger.debug(
+                    f"Inferring frames {batch[0]}-{batch[-1]}:",
+                )
+                frames = self.cache.download_frames(
+                    api, video_info.id, batch, redownload_video=True
+                )
+                anns, slides_data = self._inference_auto(
+                    source=frames,
+                    settings=inference_settings,
+                )
+
+                if inference_request.tracker is not None:
+                    anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
+
+                predictions = [
+                    Prediction(
+                        ann,
+                        model_meta=self.model_meta,
+                        frame_index=frame_index,
+                        video_id=video_info.id,
+                        dataset_id=video_info.dataset_id,
+                        project_id=video_info.project_id,
+                    )
+                    for ann, frame_index in zip(anns, batch)
+                ]
+                for pred, this_slides_data in zip(predictions, slides_data):
+                    pred.extra_data["slides_data"] = this_slides_data
+                uploader.put(predictions)
+        video_ann_json = None
+        if inference_request.tracker is not None:
             inference_request.set_stage("Postprocess...", 0, 1)
-            video_ann_json =
+            video_ann_json = inference_request.tracker.video_annotation.to_json()
         inference_request.done()
         inference_request.final_result = {"video_ann": video_ann_json}
         return video_ann_json
@@ -2268,10 +2496,9 @@ class Inference:
         upload_mode = state.get("upload_mode", None)
         iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
         if upload_mode == "iou_merge" and iou_merge_threshold is None:
-            iou_merge_threshold =
+            iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD
         cache_project_on_model = state.get("cache_project_on_model", False)
 
-        project_info = api.project.get_info_by_id(project_id)
         inference_request.context.setdefault("project_info", {})[project_id] = project_info
         dataset_ids = state.get("dataset_ids", None)
         if dataset_ids is None:
@@ -2306,7 +2533,11 @@ class Inference:
 
         if cache_project_on_model:
             download_to_cache(
-                api,
+                api,
+                project_info.id,
+                datasets_infos,
+                progress_cb=inference_request.done,
+                skip_create_readme=True,
             )
 
         images_infos_dict = {}
@@ -2315,20 +2546,9 @@ class Inference:
             if not cache_project_on_model:
                 inference_request.done(dataset_info.items_count)
 
-        def
-
-
-            with ThreadPoolExecutor(max(8, min(batch_size, 64))) as executor:
-                for image_id in image_ids:
-                    executor.submit(
-                        self.cache.download_image,
-                        api,
-                        image_id,
-                    )
-
-        if not cache_project_on_model:
-            # start downloading in parallel
-            threading.Thread(target=_download_images, args=[datasets_infos], daemon=True).start()
+        def download_f(item: int):
+            self.cache.download_image(api, item)
+            return item
 
         _upload_predictions = partial(
             self.upload_predictions,
@@ -2343,7 +2563,9 @@ class Inference:
         )
 
         _add_results_to_request = partial(
-            self.add_results_to_request,
+            self.add_results_to_request,
+            inference_request=inference_request,
+            progress_cb=inference_request.done,
         )
 
         if upload_mode is None:
@@ -2351,57 +2573,78 @@ class Inference:
         else:
             upload_f = _upload_predictions
 
+        download_workers = max(8, min(batch_size, 64))
         inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, inference_progress_total)
         with Uploader(upload_f, logger=logger) as uploader:
-
-            for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        project_info.id,
-                        dataset_info.name,
-                        [ii.name for ii in images_infos_batch],
+            with Downloader(download_f, max_workers=download_workers, logger=logger) as downloader:
+                for images in images_infos_dict.values():
+                    for image in images:
+                        downloader.put(image.id)
+                downloader.next(100)
+                for dataset_info in datasets_infos:
+                    for images_infos_batch in batched(
+                        images_infos_dict[dataset_info.id], batch_size=batch_size
+                    ):
+                        if uploader.has_exception():
+                            exception = uploader.exception
+                            raise exception
+                        if inference_request.is_stopped():
+                            logger.debug(
+                                f"Cancelling inference project...",
+                                extra={"inference_request_uuid": inference_request.uuid},
                             )
+                            return
+                        if inference_request.is_paused():
+                            logger.info("Inference request is paused. Waiting...")
+                            while inference_request.is_paused():
+                                if (
+                                    inference_request.paused_for()
+                                    > inference_request.PAUSE_SLEEP_MAX_WAIT
+                                ):
+                                    logger.info(
+                                        "Inference request has been paused for too long. Cancelling..."
+                                    )
+                                    raise RuntimeError(
+                                        "Inference request cancelled due to long pause."
+                                    )
+                                time.sleep(inference_request.PAUSE_SLEEP_INTERVAL)
+                        if cache_project_on_model:
+                            images_paths, _ = zip(
+                                *read_from_cached_project(
+                                    project_info.id,
+                                    dataset_info.name,
+                                    [ii.name for ii in images_infos_batch],
+                                )
+                            )
+                            images_nps = [sly_image.read(img_path) for img_path in images_paths]
+                        else:
+                            images_nps = self.cache.download_images(
+                                api,
+                                dataset_info.id,
+                                [info.id for info in images_infos_batch],
+                                return_images=True,
+                            )
+                        downloader.next(len(images_infos_batch))
+                        anns, slides_data = self._inference_auto(
+                            source=images_nps,
+                            settings=inference_settings,
                         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                            ann,
-                            model_meta=self.model_meta,
-                            image_id=image_info.id,
-                            name=image_info.name,
-                            dataset_id=dataset_info.id,
-                            project_id=dataset_info.project_id,
-                            image_name=image_info.name,
-                        )
-                        for ann, image_info in zip(anns, images_infos_batch)
-                    ]
-                    for pred, this_slides_data in zip(predictions, slides_data):
-                        pred.extra_data["slides_data"] = this_slides_data
+                        predictions = [
+                            Prediction(
+                                ann,
+                                model_meta=self.model_meta,
+                                image_id=image_info.id,
+                                name=image_info.name,
+                                dataset_id=dataset_info.id,
+                                project_id=dataset_info.project_id,
+                                image_name=image_info.name,
+                            )
+                            for ann, image_info in zip(anns, images_infos_batch)
+                        ]
+                        for pred, this_slides_data in zip(predictions, slides_data):
+                            pred.extra_data["slides_data"] = this_slides_data
 
-
+                        uploader.put(predictions)
 
     def _run_speedtest(
         self,
@@ -2444,7 +2687,13 @@ class Inference:
         inference_request.done()
 
         if cache_project_on_model:
-            download_to_cache(
+            download_to_cache(
+                api,
+                project_id,
+                datasets_infos,
+                progress_cb=inference_request.done,
+                skip_create_readme=True,
+            )
 
         inference_request.set_stage("warmup", 0, num_warmup)
 
@@ -2565,6 +2814,11 @@ class Inference:
     def _freeze_model(self):
         if self._model_frozen or not self._model_served:
             return
+
+        if not self._deploy_params:
+            logger.warning("Deploy params are not set, cannot freeze the model.")
+            return
+
         logger.debug("Freezing model...")
         runtime = self._deploy_params.get("runtime")
         if runtime and runtime.lower() != RuntimeType.PYTORCH.lower():
@@ -2907,11 +3161,89 @@ class Inference:
         inference_request.add_results(results)
 
     def add_results_to_request(
-        self, predictions: List[Prediction], inference_request: InferenceRequest
+        self, predictions: List[Prediction], inference_request: InferenceRequest, progress_cb=None
     ):
         results = self._format_output(predictions)
         inference_request.add_results(results)
-
+        if progress_cb:
+            progress_cb(len(results))
+
+    def upload_predictions_to_video(
+        self,
+        predictions: List[Prediction],
+        api: Api,
+        video_info: VideoInfo,
+        track_id: str,
+        context: Dict,
+        progress_cb=None,
+        inference_request: InferenceRequest = None,
+    ):
+        key_id_map = KeyIdMap()
+        project_meta = context.get("project_meta", None)
+        if project_meta is None:
+            project_meta = ProjectMeta.from_json(api.project.get_meta(video_info.project_id))
+            context["project_meta"] = project_meta
+        meta_changed = False
+        for prediction in predictions:
+            project_meta, ann, meta_changed_ = update_meta_and_ann(
+                project_meta, prediction.annotation, None
+            )
+            prediction.annotation = ann
+            meta_changed = meta_changed or meta_changed_
+        if meta_changed:
+            project_meta = api.project.update_meta(video_info.project_id, project_meta)
+            context["project_meta"] = project_meta
+
+        figure_data_by_object_id = defaultdict(list)
+
+        tracks_to_object_ids = context.setdefault("tracks_to_object_ids", {})
+        new_tracks: Dict[int, VideoObject] = {}
+        for prediction in predictions:
+            annotation = prediction.annotation
+            tracks = annotation.custom_data
+            for track, label in zip(tracks, annotation.labels):
+                if track not in tracks_to_object_ids and track not in new_tracks:
+                    video_object = VideoObject(obj_class=label.obj_class)
+                    new_tracks[track] = video_object
+        if new_tracks:
+            tracks, video_objects = zip(*new_tracks.items())
+            added_object_ids = api.video.object.append_bulk(
+                video_info.id, VideoObjectCollection(video_objects), key_id_map=key_id_map
+            )
+            for track, object_id in zip(tracks, added_object_ids):
+                tracks_to_object_ids[track] = object_id
+        for prediction in predictions:
+            annotation = prediction.annotation
+            tracks = annotation.custom_data
+            for track, label in zip(tracks, annotation.labels):
+                object_id = tracks_to_object_ids[track]
+                figure_data_by_object_id[object_id].append(
+                    {
+                        ApiField.OBJECT_ID: object_id,
+                        ApiField.GEOMETRY_TYPE: label.geometry.geometry_name(),
+                        ApiField.GEOMETRY: label.geometry.to_json(),
+                        ApiField.META: {ApiField.FRAME: prediction.frame_index},
+                        ApiField.TRACK_ID: track_id,
+                    }
+                )
+
+        for object_id, figures_data in figure_data_by_object_id.items():
+            figures_keys = [uuid.uuid4() for _ in figures_data]
+            api.video.figure._append_bulk(
+                entity_id=video_info.id,
+                figures_json=figures_data,
+                figures_keys=figures_keys,
+                key_id_map=key_id_map,
+            )
+            logger.debug(f"Added {len(figures_data)} geometries to object #{object_id}")
+        if progress_cb:
+            progress_cb(len(predictions))
+        if inference_request is not None:
+            results = self._format_output(predictions)
+            for result in results:
+                result["annotation"] = None
+                result["data"] = None
+            inference_request.add_results(results)
 
     def serve(self):
         if not self._use_gui and not self._is_cli_deploy:
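Note how the new `upload_predictions_to_video` writes results in two phases: one `VideoObject` per previously unseen track is created via `api.video.object.append_bulk`, with the track-to-object-id mapping memoized in `inference_request.context`, and then all geometries are attached per object through `api.video.figure._append_bulk` under fresh `uuid.uuid4()` keys. Results added back to the inference request have `annotation` and `data` nulled to keep responses light.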
@@ -2995,7 +3327,7 @@ class Inference:
 
         if not self._use_gui:
             Progress("Model deployed", 1).iter_done_report()
-
+        elif self.api is not None:
             autostart_func()
 
         @server.exception_handler(HTTPException)
@@ -3022,6 +3354,11 @@ class Inference:
         def get_session_info(response: Response):
             return self.get_info()
 
+        @server.post("/get_tracking_settings")
+        @self._check_serve_before_call
+        def get_tracking_settings(response: Response):
+            return self.get_tracking_settings()
+
         @server.post("/get_custom_inference_settings")
         def get_custom_inference_settings():
             return {"settings": self.custom_inference_settings}
@@ -3305,6 +3642,22 @@ class Inference:
                 "inference_request_uuid": inference_request.uuid,
             }
 
+        @server.post("/tracking_by_detection")
+        def tracking_by_detection(response: Response, request: Request):
+            state = request.state.state
+            context = request.state.context
+            state.update(context)
+            if state.get("tracker") is None:
+                state["tracker"] = "botsort"
+
+            logger.debug("Received a request to 'tracking_by_detection'", extra={"state": state})
+            self.validate_inference_state(state)
+            api = self.api_from_request(request)
+            inference_request, future = self.inference_requests_manager.schedule_task(
+                self._tracking_by_detection, api, state
+            )
+            return {"message": "Track task started."}
+
         @server.post("/inference_project_id_async")
         def inference_project_id_async(response: Response, request: Request):
             state = request.state.state
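A hedged client-side sketch of starting a job through the new endpoint. The exact request envelope depends on the middleware that populates `request.state.state`; Supervisely apps conventionally pass parameters under a `state` key, and the field names below are ones `_tracking_by_detection` reads (the URL and IDs are placeholders):

```python
import requests

payload = {
    "state": {
        "videoId": 123,        # placeholder video ID
        "tracker": "botsort",  # also the default when omitted, per the handler
        "start_frame": 0,
        "frames_count": 300,
        "tracker_settings": {},
    }
}
resp = requests.post("http://localhost:8000/tracking_by_detection", json=payload)
print(resp.json())  # {"message": "Track task started."}
```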
@@ -3368,10 +3721,7 @@ class Inference:
             data = {**inference_request.to_json(), **log_extra}
             if inference_request.stage != InferenceRequest.Stage.INFERENCE:
                 data["progress"] = {"current": 0, "total": 1}
-            logger.debug(
-                f"Sending inference progress with uuid:",
-                extra=data,
-            )
+            logger.debug(f"Sending inference progress with uuid:", extra=data)
             return data
 
         @server.post(f"/pop_inference_results")
@@ -4228,10 +4578,10 @@ class Inference:
                 self._args.draw,
             )
 
-    def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation]):
+    def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation], tracker):
         updated_anns = []
         for frame, ann in zip(frames, anns):
-            matches =
+            matches = tracker.update(frame, ann)
             track_ids = [match["track_id"] for match in matches]
             tracked_labels = [match["label"] for match in matches]
 
@@ -4297,61 +4647,72 @@ class Inference:
     def export_tensorrt(self, deploy_params: dict):
         raise NotImplementedError("Have to be implemented in child class after inheritance")
 
-
-
-
-        dataset_id: int,
-        gt_image_ids: List[int],
-        iou: float = None,
-        meta: Optional[ProjectMeta] = None,
+
+def _filter_duplicated_predictions_from_ann_cpu(
+    gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
 ):
     """
-    Filter out
+    Filter out predicted labels whose bboxes have IoU > iou_threshold with any GT label.
+    Uses Shapely for geometric operations.
 
-
-
-
-
-    - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
+    Args:
+        pred_ann: Predicted annotation object
+        gt_ann: Ground truth annotation object
+        iou_threshold: IoU threshold for filtering
 
-    :
-
-    :param pred_anns: List of Annotation objects containing predictions
-    :type pred_anns: List[Annotation]
-    :param dataset_id: ID of the dataset containing the images
-    :type dataset_id: int
-    :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
-    :type gt_image_ids: List[int]
-    :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
-        ground truth box of the same class will be removed. None if no filtering is needed
-    :type iou: Optional[float]
-    :param meta: ProjectMeta object
-    :type meta: Optional[ProjectMeta]
-    :return: List of Annotation objects containing filtered predictions
-    :rtype: List[Annotation]
-
-    Notes:
-    ------
-    - Requires PyTorch and torchvision for IoU calculations
-    - This method is useful for identifying new objects that aren't already annotated in the ground truth
+    Returns:
+        New annotation with filtered labels
     """
-    if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not iou_threshold:
+        return pred_ann
+
+    from shapely.geometry import box
+
+    def calculate_iou(geom1: Geometry, geom2: Geometry):
+        """Calculate IoU between two geometries using Shapely."""
+        bbox1 = geom1.to_bbox()
+        bbox2 = geom2.to_bbox()
+
+        box1 = box(bbox1.left, bbox1.top, bbox1.right, bbox1.bottom)
+        box2 = box(bbox2.left, bbox2.top, bbox2.right, bbox2.bottom)
+
+        intersection = box1.intersection(box2).area
+        union = box1.union(box2).area
+
+        return intersection / union if union > 0 else 0.0
+
+    new_labels = []
+    pred_cls_bboxes = defaultdict(list)
+    for label in pred_ann.labels:
+        name_shape = (label.obj_class.name, label.geometry.name())
+        pred_cls_bboxes[name_shape].append(label)
+
+    gt_cls_bboxes = defaultdict(list)
+    for label in gt_ann.labels:
+        name_shape = (label.obj_class.name, label.geometry.name())
+        if name_shape not in pred_cls_bboxes:
+            continue
+        gt_cls_bboxes[name_shape].append(label)
+
+    for name_shape, pred in pred_cls_bboxes.items():
+        gt = gt_cls_bboxes[name_shape]
+        if len(gt) == 0:
+            new_labels.extend(pred)
+            continue
+
+        for pred_label in pred:
+            # Check if this prediction has IoU < threshold with ALL GT boxes
+            keep = True
+            for gt_label in gt:
+                iou = calculate_iou(pred_label.geometry, gt_label.geometry)
+                if iou >= iou_threshold:
+                    keep = False
+                    break
+
+            if keep:
+                new_labels.append(pred_label)
+
+    return pred_ann.clone(labels=new_labels)
 
 
 def _filter_duplicated_predictions_from_ann(
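The IoU math in the CPU fallback is plain Shapely box algebra; a standalone check of what `calculate_iou` computes:

```python
from shapely.geometry import box

a = box(0, 0, 10, 10)  # (minx, miny, maxx, maxy), area 100
b = box(5, 5, 15, 15)  # overlaps `a` in a 5x5 corner
intersection = a.intersection(b).area  # 25.0
union = a.union(b).area                # 100 + 100 - 25 = 175.0
print(intersection / union)            # ~0.1429
```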
@@ -4382,13 +4743,15 @@ def _filter_duplicated_predictions_from_ann(
     - Predictions with classes not present in ground truth will be kept
     - Requires PyTorch and torchvision for IoU calculations
     """
+    if not iou_threshold:
+        return pred_ann
 
     try:
         import torch
         from torchvision.ops import box_iou
 
     except ImportError:
-
+        return _filter_duplicated_predictions_from_ann_cpu(gt_ann, pred_ann, iou_threshold)
 
     def _to_tensor(geom):
         return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
@@ -4396,16 +4759,18 @@ def _filter_duplicated_predictions_from_ann(
     new_labels = []
     pred_cls_bboxes = defaultdict(list)
     for label in pred_ann.labels:
-
+        name_shape = (label.obj_class.name, label.geometry.name())
+        pred_cls_bboxes[name_shape].append(label)
 
     gt_cls_bboxes = defaultdict(list)
     for label in gt_ann.labels:
-
+        name_shape = (label.obj_class.name, label.geometry.name())
+        if name_shape not in pred_cls_bboxes:
             continue
-        gt_cls_bboxes[
+        gt_cls_bboxes[name_shape].append(label)
 
-    for
-        gt = gt_cls_bboxes[
+    for name_shape, pred in pred_cls_bboxes.items():
+        gt = gt_cls_bboxes[name_shape]
         if len(gt) == 0:
             new_labels.extend(pred)
             continue
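On the torch path the same decision reduces to a single `torchvision.ops.box_iou` call over `(x1, y1, x2, y2)` tensors; a minimal demonstration of the keep/drop logic at the 0.9 default threshold:

```python
import torch
from torchvision.ops import box_iou

gt = torch.tensor([[0.0, 0.0, 10.0, 10.0]])      # one ground-truth box
pred = torch.tensor([[0.0, 0.0, 10.0, 10.0],      # duplicate of GT -> dropped
                     [50.0, 50.0, 60.0, 60.0]])   # disjoint -> kept
iou = box_iou(pred, gt)                           # pairwise matrix, shape (2, 1)
keep = (iou < 0.9).all(dim=1)
print(keep)  # tensor([False,  True])
```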
@@ -4419,6 +4784,63 @@ def _filter_duplicated_predictions_from_ann(
     return pred_ann.clone(labels=new_labels)


+def _exclude_duplicated_predictions(
+    api: Api,
+    pred_anns: List[Annotation],
+    dataset_id: int,
+    gt_image_ids: List[int],
+    iou: float = None,
+    meta: Optional[ProjectMeta] = None,
+):
+    """
+    Filter out predictions that significantly overlap with ground truth (GT) objects.
+
+    This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
+        - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
+        - Gets the ProjectMeta object if it is not provided
+        - Downloads GT annotations for the specified image IDs
+        - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
+
+    :param api: Supervisely API object
+    :type api: Api
+    :param pred_anns: List of Annotation objects containing predictions
+    :type pred_anns: List[Annotation]
+    :param dataset_id: ID of the dataset containing the images
+    :type dataset_id: int
+    :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
+    :type gt_image_ids: List[int]
+    :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
+        ground truth box of the same class will be removed. None if no filtering is needed
+    :type iou: Optional[float]
+    :param meta: ProjectMeta object
+    :type meta: Optional[ProjectMeta]
+    :return: List of Annotation objects containing filtered predictions
+    :rtype: List[Annotation]
+
+    Notes:
+    ------
+    - Requires PyTorch and torchvision for IoU calculations
+    - This method is useful for identifying new objects that aren't already annotated in the ground truth
+    """
+    if isinstance(iou, float) and 0 < iou <= 1:
+        if meta is None:
+            ds = api.dataset.get_info_by_id(dataset_id)
+            meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
+        gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
+        gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
+        for i in range(0, len(pred_anns)):
+            before = len(pred_anns[i].labels)
+            with Timer() as timer:
+                pred_anns[i] = _filter_duplicated_predictions_from_ann(
+                    gt_anns[i], pred_anns[i], iou
+                )
+            after = len(pred_anns[i].labels)
+            logger.debug(
+                f"[{i}]: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
+            )
+    return pred_anns
+
+
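A hedged usage sketch of the new wrapper; `run_model`, the IDs, and the 0.9 threshold are placeholders (0.9 mirrors a typical `existing_objects_iou_thresh` value):

api = Api.from_env()
dataset_id = 42                      # placeholder
gt_image_ids = [101, 102, 103]       # placeholder IDs, all from one dataset
pred_anns = run_model(gt_image_ids)  # hypothetical helper -> List[Annotation]

# Keep only predictions not already covered (IoU >= 0.9) by a same-class GT box.
pred_anns = _exclude_duplicated_predictions(
    api, pred_anns, dataset_id, gt_image_ids, iou=0.9
)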
 def _get_log_extra_for_inference_request(
     inference_request_uuid, inference_request: Union[InferenceRequest, dict]
 ):
@@ -4526,7 +4948,7 @@ def get_gpu_count():
         gpu_count = len(re.findall(r"GPU \d+:", nvidia_smi_output))
         return gpu_count
     except (subprocess.CalledProcessError, FileNotFoundError) as exc:
-        logger.
+        logger.warning(f"Calling nvidia-smi caused an error: {exc}. Assuming there is no GPU.")
         return 0


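For context, `get_gpu_count` (only its tail is visible in this hunk) counts devices listed in `nvidia-smi` output with the regex shown above. A standalone sketch of the same idea; the `-L` flag is an assumption, since the exact command line sits outside this hunk:

import re
import subprocess

def count_gpus():
    # `nvidia-smi -L` prints one line per device: "GPU 0: ... (UUID: ...)".
    try:
        output = subprocess.check_output(["nvidia-smi", "-L"], text=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        return 0  # nvidia-smi missing or failing: assume no GPU
    return len(re.findall(r"GPU \d+:", output))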
@@ -4706,7 +5128,180 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation, model_prediction_suf
         img_tags = None
     if not any_label_updated:
         labels = None
-    ann = ann.clone(img_tags=
+    ann = ann.clone(img_tags=img_tags)
+    return meta, ann, meta_changed
+
+
+def update_meta_and_ann_for_video_annotation(
+    meta: ProjectMeta, ann: VideoAnnotation, model_prediction_suffix: str = None
+):
+    """Update project meta and annotation to match each other
+    If obj class or tag meta from annotation conflicts with project meta
+    add suffix to obj class or tag meta.
+    Return tuple of updated project meta, annotation and boolean flag if meta was changed.
+    """
+    obj_classes_suffixes = ["_nn"]
+    tag_meta_suffixes = ["_nn"]
+    if model_prediction_suffix is not None:
+        obj_classes_suffixes = [model_prediction_suffix]
+        tag_meta_suffixes = [model_prediction_suffix]
+        logger.debug(
+            f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
+        )
+    logger.debug("source meta", extra={"meta": meta.to_json()})
+    meta_changed = False
+
+    # meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
+    # if replaced_classes_in_meta:
+    #     meta_changed = True
+    #     logger.warning(
+    #         "Some classes names were fixed in project meta",
+    #         extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
+    #     )
+
+    new_objects: List[VideoObject] = []
+    new_figures: List[VideoFigure] = []
+    any_object_updated = False
+    for video_object in ann.objects:
+        this_object_figures = [
+            figure for figure in ann.figures if figure.video_object.key() == video_object.key()
+        ]
+        this_object_changed = False
+        original_obj_class_name = video_object.obj_class.name
+        suffix_found = False
+        for suffix in ["", *obj_classes_suffixes]:
+            obj_class = video_object.obj_class
+            obj_class_name = obj_class.name + suffix
+            if suffix:
+                obj_class = obj_class.clone(name=obj_class_name)
+                video_object = video_object.clone(obj_class=obj_class)
+                any_object_updated = True
+                this_object_changed = True
+            meta_obj_class = meta.get_obj_class(obj_class_name)
+            if meta_obj_class is None:
+                # obj class is not in meta, add it with suffix
+                meta = meta.add_obj_class(obj_class)
+                new_objects.append(video_object)
+                meta_changed = True
+                suffix_found = True
+                break
+            elif (
+                meta_obj_class.geometry_type.geometry_name()
+                == video_object.obj_class.geometry_type.geometry_name()
+            ):
+                # if object geometry is the same as in meta, use meta obj class
+                video_object = video_object.clone(obj_class=meta_obj_class)
+                new_objects.append(video_object)
+                suffix_found = True
+                any_object_updated = True
+                this_object_changed = True
+                break
+            elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
+                # if meta obj class is AnyGeometry, use it in object
+                video_object = video_object.clone(obj_class=meta_obj_class)
+                new_objects.append(video_object)
+                suffix_found = True
+                any_object_updated = True
+                this_object_changed = True
+                break
+        if not suffix_found:
+            # if no suffix found, raise error
+            raise ValueError(
+                f"Can't add obj class {original_obj_class_name} to project meta. "
+                "Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
+                "Please check if model geometry type is compatible with existing obj classes."
+            )
+        elif this_object_changed:
+            this_object_figures = [
+                figure.clone(video_object=video_object) for figure in this_object_figures
+            ]
+        new_figures.extend(this_object_figures)
+    if any_object_updated:
+        frames_figures = {}
+        for figure in new_figures:
+            frames_figures.setdefault(figure.frame_index, []).append(figure)
+        new_frames = FrameCollection(
+            [
+                Frame(index=frame_index, figures=figures)
+                for frame_index, figures in frames_figures.items()
+            ]
+        )
+        ann = ann.clone(objects=new_objects, frames=new_frames)
+
+    # check if tag metas are in project meta
+    # if not, add them with suffix
+    ann_tag_metas: Dict[str, TagMeta] = {}
+    for video_object in ann.objects:
+        for tag in video_object.tags:
+            tag_name = tag.meta.name
+            if tag_name not in ann_tag_metas:
+                ann_tag_metas[tag_name] = tag.meta
+    for tag in ann.tags:
+        tag_name = tag.meta.name
+        if tag_name not in ann_tag_metas:
+            ann_tag_metas[tag_name] = tag.meta
+
+    changed_tag_metas = {}
+    for ann_tag_meta in ann_tag_metas.values():
+        meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
+        if meta_tag_meta is None:
+            meta = meta.add_tag_meta(ann_tag_meta)
+            meta_changed = True
+        elif not meta_tag_meta.is_compatible(ann_tag_meta):
+            suffix_found = False
+            for suffix in tag_meta_suffixes:
+                new_tag_meta_name = ann_tag_meta.name + suffix
+                meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
+                if meta_tag_meta is None:
+                    new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
+                    meta = meta.add_tag_meta(new_tag_meta)
+                    changed_tag_metas[ann_tag_meta.name] = new_tag_meta
+                    meta_changed = True
+                    suffix_found = True
+                    break
+                if meta_tag_meta.is_compatible(ann_tag_meta):
+                    changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
+                    suffix_found = True
+                    break
+            if not suffix_found:
+                raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")
+
+    if changed_tag_metas:
+        objects = []
+        any_object_updated = False
+        for video_object in ann.objects:
+            any_tag_updated = False
+            object_tags = []
+            for tag in video_object.tags:
+                if tag.meta.name in changed_tag_metas:
+                    object_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+                    any_tag_updated = True
+                else:
+                    object_tags.append(tag)
+            if any_tag_updated:
+                video_object = video_object.clone(tags=TagCollection(object_tags))
+                any_object_updated = True
+            objects.append(video_object)
+
+        video_tags = []
+        any_tag_updated = False
+        for tag in ann.tags:
+            if tag.meta.name in changed_tag_metas:
+                video_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+                any_tag_updated = True
+            else:
+                video_tags.append(tag)
+        if any_tag_updated or any_object_updated:
+            if any_tag_updated:
+                video_tags = VideoTagCollection(video_tags)
+            else:
+                video_tags = None
+            if any_object_updated:
+                objects = VideoObjectCollection(objects)
+            else:
+                objects = None
+            ann = ann.clone(tags=video_tags, objects=objects)
+
     return meta, ann, meta_changed


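A hedged sketch of how the new video variant is expected to be called; `project_id`, `video_id`, and `model.predict_video` are placeholders, not package API:

api = Api.from_env()
project_id, video_id = 7, 99                                    # placeholders
meta = ProjectMeta.from_json(api.project.get_meta(project_id))
video_ann = model.predict_video(video_id)                       # hypothetical inference call

meta, video_ann, meta_changed = update_meta_and_ann_for_video_annotation(
    meta, video_ann, model_prediction_suffix="_model"
)
if meta_changed:
    # Push the extended meta before uploading the reconciled annotation.
    api.project.update_meta(project_id, meta.to_json())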
@@ -4820,7 +5415,8 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
             return data[key]
     return None

-
+
+def torch_load_safe(checkpoint_path: str, device: str = "cpu"):
     import torch  # pylint: disable=import-error

     # TODO: handle torch.load(weights_only=True) - change in torch 2.6.0