supervisely 6.73.438__py3-none-any.whl → 6.73.513__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/__init__.py +137 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/annotation.py +8 -2
- supervisely/annotation/json_geometries_map.py +14 -11
- supervisely/annotation/label.py +80 -3
- supervisely/api/annotation_api.py +14 -11
- supervisely/api/api.py +59 -38
- supervisely/api/app_api.py +11 -2
- supervisely/api/dataset_api.py +74 -12
- supervisely/api/entities_collection_api.py +10 -0
- supervisely/api/entity_annotation/figure_api.py +52 -4
- supervisely/api/entity_annotation/object_api.py +3 -3
- supervisely/api/entity_annotation/tag_api.py +63 -12
- supervisely/api/guides_api.py +210 -0
- supervisely/api/image_api.py +72 -1
- supervisely/api/labeling_job_api.py +83 -1
- supervisely/api/labeling_queue_api.py +33 -7
- supervisely/api/module_api.py +9 -0
- supervisely/api/project_api.py +71 -26
- supervisely/api/storage_api.py +3 -1
- supervisely/api/task_api.py +13 -2
- supervisely/api/team_api.py +4 -3
- supervisely/api/video/video_annotation_api.py +119 -3
- supervisely/api/video/video_api.py +65 -14
- supervisely/api/video/video_figure_api.py +24 -11
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +23 -7
- supervisely/app/development/development.py +18 -2
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +105 -0
- supervisely/app/fastapi/subapp.py +88 -42
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +6 -0
- supervisely/app/widgets/activity_feed/__init__.py +0 -0
- supervisely/app/widgets/activity_feed/activity_feed.py +239 -0
- supervisely/app/widgets/activity_feed/style.css +78 -0
- supervisely/app/widgets/activity_feed/template.html +22 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/classes_list_selector/classes_list_selector.py +121 -9
- supervisely/app/widgets/classes_list_selector/template.html +60 -93
- supervisely/app/widgets/classes_mapping/classes_mapping.py +13 -12
- supervisely/app/widgets/classes_table/classes_table.py +1 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +1 -1
- supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
- supervisely/app/widgets/fast_table/fast_table.py +184 -60
- supervisely/app/widgets/fast_table/template.html +1 -1
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +564 -0
- supervisely/app/widgets/heatmap/script.js +533 -0
- supervisely/app/widgets/heatmap/style.css +233 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/modal/__init__.py +0 -0
- supervisely/app/widgets/modal/modal.py +198 -0
- supervisely/app/widgets/modal/template.html +10 -0
- supervisely/app/widgets/object_class_view/object_class_view.py +3 -0
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -3
- supervisely/app/widgets/select_class/__init__.py +0 -0
- supervisely/app/widgets/select_class/select_class.py +363 -0
- supervisely/app/widgets/select_class/template.html +50 -0
- supervisely/app/widgets/select_cuda/select_cuda.py +22 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
- supervisely/app/widgets/select_tag/__init__.py +0 -0
- supervisely/app/widgets/select_tag/select_tag.py +352 -0
- supervisely/app/widgets/select_tag/template.html +64 -0
- supervisely/app/widgets/select_team/select_team.py +37 -4
- supervisely/app/widgets/select_team/template.html +4 -5
- supervisely/app/widgets/select_user/__init__.py +0 -0
- supervisely/app/widgets/select_user/select_user.py +270 -0
- supervisely/app/widgets/select_user/template.html +13 -0
- supervisely/app/widgets/select_workspace/select_workspace.py +59 -10
- supervisely/app/widgets/select_workspace/template.html +9 -12
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/aug/aug.py +6 -2
- supervisely/convert/base_converter.py +1 -0
- supervisely/convert/converter.py +2 -2
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/image/image_converter.py +3 -1
- supervisely/convert/image/image_helper.py +48 -4
- supervisely/convert/image/label_studio/label_studio_converter.py +2 -0
- supervisely/convert/image/medical2d/medical2d_helper.py +2 -24
- supervisely/convert/image/multispectral/multispectral_converter.py +6 -0
- supervisely/convert/image/pascal_voc/pascal_voc_converter.py +8 -5
- supervisely/convert/image/pascal_voc/pascal_voc_helper.py +7 -0
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_converter.py +33 -3
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_helper.py +12 -5
- supervisely/convert/pointcloud/las/las_converter.py +13 -1
- supervisely/convert/pointcloud/las/las_helper.py +110 -11
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +27 -16
- supervisely/convert/pointcloud/pointcloud_converter.py +91 -3
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +58 -22
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +21 -47
- supervisely/convert/video/__init__.py +1 -0
- supervisely/convert/video/multi_view/__init__.py +0 -0
- supervisely/convert/video/multi_view/multi_view.py +543 -0
- supervisely/convert/video/sly/sly_video_converter.py +359 -3
- supervisely/convert/video/video_converter.py +24 -4
- supervisely/convert/volume/dicom/dicom_converter.py +13 -5
- supervisely/convert/volume/dicom/dicom_helper.py +30 -18
- supervisely/geometry/constants.py +1 -0
- supervisely/geometry/geometry.py +4 -0
- supervisely/geometry/helpers.py +5 -1
- supervisely/geometry/oriented_bbox.py +676 -0
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/geometry/rectangle.py +2 -1
- supervisely/io/env.py +76 -1
- supervisely/io/fs.py +21 -0
- supervisely/nn/benchmark/base_evaluator.py +104 -11
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -8
- supervisely/nn/benchmark/object_detection/evaluator.py +20 -4
- supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +10 -5
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +34 -16
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -1
- supervisely/nn/benchmark/visualization/evaluation_result.py +66 -4
- supervisely/nn/inference/cache.py +43 -18
- supervisely/nn/inference/gui/serving_gui_template.py +5 -2
- supervisely/nn/inference/inference.py +916 -222
- supervisely/nn/inference/inference_request.py +55 -10
- supervisely/nn/inference/predict_app/gui/classes_selector.py +83 -12
- supervisely/nn/inference/predict_app/gui/gui.py +676 -488
- supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
- supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
- supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
- supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
- supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
- supervisely/nn/inference/predict_app/gui/utils.py +236 -119
- supervisely/nn/inference/predict_app/predict_app.py +2 -2
- supervisely/nn/inference/session.py +43 -35
- supervisely/nn/inference/tracking/bbox_tracking.py +118 -35
- supervisely/nn/inference/tracking/point_tracking.py +5 -1
- supervisely/nn/inference/tracking/tracker_interface.py +10 -1
- supervisely/nn/inference/uploader.py +139 -12
- supervisely/nn/live_training/__init__.py +7 -0
- supervisely/nn/live_training/api_server.py +111 -0
- supervisely/nn/live_training/artifacts_utils.py +243 -0
- supervisely/nn/live_training/checkpoint_utils.py +229 -0
- supervisely/nn/live_training/dynamic_sampler.py +44 -0
- supervisely/nn/live_training/helpers.py +14 -0
- supervisely/nn/live_training/incremental_dataset.py +146 -0
- supervisely/nn/live_training/live_training.py +497 -0
- supervisely/nn/live_training/loss_plateau_detector.py +111 -0
- supervisely/nn/live_training/request_queue.py +52 -0
- supervisely/nn/model/model_api.py +9 -0
- supervisely/nn/model/prediction.py +2 -1
- supervisely/nn/model/prediction_session.py +26 -14
- supervisely/nn/prediction_dto.py +19 -1
- supervisely/nn/tracker/base_tracker.py +11 -1
- supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
- supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py +7 -4
- supervisely/nn/tracker/botsort_tracker.py +94 -65
- supervisely/nn/tracker/utils.py +4 -5
- supervisely/nn/tracker/visualize.py +93 -93
- supervisely/nn/training/gui/classes_selector.py +16 -1
- supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
- supervisely/nn/training/train_app.py +46 -31
- supervisely/project/data_version.py +115 -51
- supervisely/project/download.py +1 -1
- supervisely/project/pointcloud_episode_project.py +37 -8
- supervisely/project/pointcloud_project.py +30 -2
- supervisely/project/project.py +14 -2
- supervisely/project/project_meta.py +27 -1
- supervisely/project/project_settings.py +32 -18
- supervisely/project/versioning/__init__.py +1 -0
- supervisely/project/versioning/common.py +20 -0
- supervisely/project/versioning/schema_fields.py +35 -0
- supervisely/project/versioning/video_schema.py +221 -0
- supervisely/project/versioning/volume_schema.py +87 -0
- supervisely/project/video_project.py +717 -15
- supervisely/project/volume_project.py +623 -5
- supervisely/template/experiment/experiment.html.jinja +4 -4
- supervisely/template/experiment/experiment_generator.py +14 -21
- supervisely/template/live_training/__init__.py +0 -0
- supervisely/template/live_training/header.html.jinja +96 -0
- supervisely/template/live_training/live_training.html.jinja +51 -0
- supervisely/template/live_training/live_training_generator.py +464 -0
- supervisely/template/live_training/sly-style.css +402 -0
- supervisely/template/live_training/template.html.jinja +18 -0
- supervisely/versions.json +28 -26
- supervisely/video/sampling.py +39 -20
- supervisely/video/video.py +41 -12
- supervisely/video_annotation/video_figure.py +38 -4
- supervisely/video_annotation/video_object.py +29 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/METADATA +58 -40
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/RECORD +203 -155
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/WHEEL +1 -1
- supervisely_lib/__init__.py +6 -1
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info/licenses}/LICENSE +0 -0
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/top_level.txt +0 -0
|
@@ -5,12 +5,14 @@ import asyncio
|
|
|
5
5
|
import inspect
|
|
6
6
|
import json
|
|
7
7
|
import os
|
|
8
|
+
import queue
|
|
8
9
|
import re
|
|
9
10
|
import shutil
|
|
10
11
|
import subprocess
|
|
11
12
|
import tempfile
|
|
12
13
|
import threading
|
|
13
14
|
import time
|
|
15
|
+
import uuid
|
|
14
16
|
from collections import OrderedDict, defaultdict
|
|
15
17
|
from concurrent.futures import ThreadPoolExecutor
|
|
16
18
|
from dataclasses import asdict, dataclass
|
|
@@ -45,13 +47,14 @@ from supervisely._utils import (
|
|
|
45
47
|
rand_str,
|
|
46
48
|
)
|
|
47
49
|
from supervisely.annotation.annotation import Annotation
|
|
48
|
-
from supervisely.annotation.label import Label
|
|
50
|
+
from supervisely.annotation.label import Label, LabelingStatus
|
|
49
51
|
from supervisely.annotation.obj_class import ObjClass
|
|
50
52
|
from supervisely.annotation.tag_collection import TagCollection
|
|
51
53
|
from supervisely.annotation.tag_meta import TagMeta, TagValueType
|
|
52
54
|
from supervisely.api.api import Api, ApiField
|
|
53
55
|
from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
|
|
54
56
|
from supervisely.api.image_api import ImageInfo
|
|
57
|
+
from supervisely.api.video.video_api import VideoInfo
|
|
55
58
|
from supervisely.app.content import get_data_dir
|
|
56
59
|
from supervisely.app.fastapi.subapp import (
|
|
57
60
|
Application,
|
|
@@ -67,6 +70,7 @@ from supervisely.decorators.inference import (
|
|
|
67
70
|
process_images_batch_sliding_window,
|
|
68
71
|
)
|
|
69
72
|
from supervisely.geometry.any_geometry import AnyGeometry
|
|
73
|
+
from supervisely.geometry.geometry import Geometry
|
|
70
74
|
from supervisely.imaging.color import get_predefined_colors
|
|
71
75
|
from supervisely.io.fs import list_files
|
|
72
76
|
from supervisely.nn.experiments import ExperimentInfo
|
|
@@ -75,7 +79,7 @@ from supervisely.nn.inference.inference_request import (
|
|
|
75
79
|
InferenceRequest,
|
|
76
80
|
InferenceRequestsManager,
|
|
77
81
|
)
|
|
78
|
-
from supervisely.nn.inference.uploader import Uploader
|
|
82
|
+
from supervisely.nn.inference.uploader import Downloader, Uploader
|
|
79
83
|
from supervisely.nn.model.model_api import ModelAPI, Prediction
|
|
80
84
|
from supervisely.nn.prediction_dto import Prediction as PredictionDTO
|
|
81
85
|
from supervisely.nn.utils import (
|
|
@@ -94,6 +98,17 @@ from supervisely.project.project_meta import ProjectMeta
|
|
|
94
98
|
from supervisely.sly_logger import logger
|
|
95
99
|
from supervisely.task.progress import Progress
|
|
96
100
|
from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
|
|
101
|
+
from supervisely.video_annotation.frame import Frame
|
|
102
|
+
from supervisely.video_annotation.frame_collection import FrameCollection
|
|
103
|
+
from supervisely.video_annotation.key_id_map import KeyIdMap
|
|
104
|
+
from supervisely.video_annotation.video_annotation import VideoAnnotation
|
|
105
|
+
from supervisely.video_annotation.video_figure import VideoFigure
|
|
106
|
+
from supervisely.video_annotation.video_object import VideoObject
|
|
107
|
+
from supervisely.video_annotation.video_object_collection import (
|
|
108
|
+
VideoObject,
|
|
109
|
+
VideoObjectCollection,
|
|
110
|
+
)
|
|
111
|
+
from supervisely.video_annotation.video_tag_collection import VideoTagCollection
|
|
97
112
|
|
|
98
113
|
try:
|
|
99
114
|
from typing import Literal
|
|
@@ -140,6 +155,7 @@ class Inference:
|
|
|
140
155
|
"""Default batch size for inference"""
|
|
141
156
|
INFERENCE_SETTINGS: str = None
|
|
142
157
|
"""Path to file with custom inference settings"""
|
|
158
|
+
DEFAULT_IOU_MERGE_THRESHOLD: float = 0.9
|
|
143
159
|
|
|
144
160
|
def __init__(
|
|
145
161
|
self,
|
|
@@ -193,7 +209,6 @@ class Inference:
|
|
|
193
209
|
self._task_id = None
|
|
194
210
|
self._sliding_window_mode = sliding_window_mode
|
|
195
211
|
self._autostart_delay_time = 5 * 60 # 5 min
|
|
196
|
-
self._tracker = None
|
|
197
212
|
self._hardware: str = None
|
|
198
213
|
if custom_inference_settings is None:
|
|
199
214
|
if self.INFERENCE_SETTINGS is not None:
|
|
@@ -427,7 +442,7 @@ class Inference:
|
|
|
427
442
|
|
|
428
443
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
429
444
|
except Exception as e:
|
|
430
|
-
logger.
|
|
445
|
+
logger.warning(
|
|
431
446
|
f"Device auto detection failed, set to default 'cpu', reason: {repr(e)}"
|
|
432
447
|
)
|
|
433
448
|
device = "cpu"
|
|
@@ -734,15 +749,15 @@ class Inference:
|
|
|
734
749
|
for model in self.pretrained_models:
|
|
735
750
|
model_meta = model.get("meta")
|
|
736
751
|
if model_meta is not None:
|
|
737
|
-
|
|
738
|
-
if
|
|
739
|
-
if
|
|
752
|
+
this_model_name = model_meta.get("model_name")
|
|
753
|
+
if this_model_name is not None:
|
|
754
|
+
if this_model_name.lower() == model_name.lower():
|
|
740
755
|
selected_model = model
|
|
741
756
|
break
|
|
742
757
|
else:
|
|
743
|
-
|
|
744
|
-
if
|
|
745
|
-
if
|
|
758
|
+
this_model_name = model.get("model_name")
|
|
759
|
+
if this_model_name is not None:
|
|
760
|
+
if this_model_name.lower() == model_name.lower():
|
|
746
761
|
selected_model = model
|
|
747
762
|
break
|
|
748
763
|
|
|
@@ -863,6 +878,50 @@ class Inference:
|
|
|
863
878
|
self.gui.download_progress.hide()
|
|
864
879
|
return local_model_files
|
|
865
880
|
|
|
881
|
+
def _fallback_download_custom_model_pt(self, deploy_params: dict):
|
|
882
|
+
"""
|
|
883
|
+
Downloads the PyTorch checkpoint from Team Files if TensorRT is failed to load.
|
|
884
|
+
"""
|
|
885
|
+
team_id = sly_env.team_id()
|
|
886
|
+
|
|
887
|
+
checkpoint_name = sly_fs.get_file_name(deploy_params["model_files"]["checkpoint"])
|
|
888
|
+
artifacts_dir = deploy_params["model_info"]["artifacts_dir"]
|
|
889
|
+
checkpoints_dir = os.path.join(artifacts_dir, "checkpoints")
|
|
890
|
+
checkpoint_ext = sly_fs.get_file_ext(deploy_params["model_info"]["checkpoints"][0])
|
|
891
|
+
|
|
892
|
+
pt_checkpoint_name = f"{checkpoint_name}{checkpoint_ext}"
|
|
893
|
+
remote_checkpoint_path = os.path.join(checkpoints_dir, pt_checkpoint_name)
|
|
894
|
+
local_checkpoint_path = os.path.join(self.model_dir, pt_checkpoint_name)
|
|
895
|
+
|
|
896
|
+
file_info = self.api.file.get_info_by_path(team_id, remote_checkpoint_path)
|
|
897
|
+
file_size = file_info.sizeb
|
|
898
|
+
if self.gui is not None:
|
|
899
|
+
with self.gui.download_progress(
|
|
900
|
+
message=f"Fallback. Downloading PyTorch checkpoint: '{pt_checkpoint_name}'",
|
|
901
|
+
total=file_size,
|
|
902
|
+
unit="bytes",
|
|
903
|
+
unit_scale=True,
|
|
904
|
+
) as download_pbar:
|
|
905
|
+
self.gui.download_progress.show()
|
|
906
|
+
self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path, progress_cb=download_pbar.update)
|
|
907
|
+
self.gui.download_progress.hide()
|
|
908
|
+
else:
|
|
909
|
+
self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path)
|
|
910
|
+
|
|
911
|
+
return local_checkpoint_path
|
|
912
|
+
|
|
913
|
+
def _remove_exported_checkpoints(self, checkpoint_path: str):
|
|
914
|
+
"""
|
|
915
|
+
Removes the exported checkpoints for provided PyTorch checkpoint path.
|
|
916
|
+
"""
|
|
917
|
+
checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
|
|
918
|
+
onnx_path = checkpoint_path.replace(checkpoint_ext, ".onnx")
|
|
919
|
+
engine_path = checkpoint_path.replace(checkpoint_ext, ".engine")
|
|
920
|
+
if os.path.exists(onnx_path):
|
|
921
|
+
sly_fs.silent_remove(onnx_path)
|
|
922
|
+
if os.path.exists(engine_path):
|
|
923
|
+
sly_fs.silent_remove(engine_path)
|
|
924
|
+
|
|
866
925
|
def _download_custom_model(self, model_files: dict, log_progress: bool = True):
|
|
867
926
|
"""
|
|
868
927
|
Downloads the custom model data.
|
|
@@ -1060,7 +1119,41 @@ class Inference:
|
|
|
1060
1119
|
self.runtime = deploy_params.get("runtime", RuntimeType.PYTORCH)
|
|
1061
1120
|
self.model_precision = deploy_params.get("model_precision", ModelPrecision.FP32)
|
|
1062
1121
|
self._hardware = get_hardware_info(self.device)
|
|
1063
|
-
|
|
1122
|
+
|
|
1123
|
+
model_files = deploy_params.get("model_files", None)
|
|
1124
|
+
if model_files is not None:
|
|
1125
|
+
checkpoint_path = deploy_params["model_files"]["checkpoint"]
|
|
1126
|
+
checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
|
|
1127
|
+
if self.runtime == RuntimeType.TENSORRT and checkpoint_ext == ".engine":
|
|
1128
|
+
try:
|
|
1129
|
+
self.load_model(**deploy_params)
|
|
1130
|
+
except Exception as e:
|
|
1131
|
+
logger.warning(
|
|
1132
|
+
f"Failed to load model with TensorRT. Downloading PyTorch to export to TensorRT. Error: {repr(e)}"
|
|
1133
|
+
)
|
|
1134
|
+
checkpoint_path = self._fallback_download_custom_model_pt(deploy_params)
|
|
1135
|
+
deploy_params["model_files"]["checkpoint"] = checkpoint_path
|
|
1136
|
+
logger.info("Exporting PyTorch model to TensorRT...")
|
|
1137
|
+
self._remove_exported_checkpoints(checkpoint_path)
|
|
1138
|
+
checkpoint_path = self.export_tensorrt(deploy_params)
|
|
1139
|
+
deploy_params["model_files"]["checkpoint"] = checkpoint_path
|
|
1140
|
+
self.load_model(**deploy_params)
|
|
1141
|
+
if checkpoint_ext in (".pt", ".pth") and not self.runtime == RuntimeType.PYTORCH:
|
|
1142
|
+
if self.runtime == RuntimeType.ONNXRUNTIME:
|
|
1143
|
+
logger.info("Exporting PyTorch model to ONNX...")
|
|
1144
|
+
self._remove_exported_checkpoints(checkpoint_path)
|
|
1145
|
+
checkpoint_path = self.export_onnx(deploy_params)
|
|
1146
|
+
elif self.runtime == RuntimeType.TENSORRT:
|
|
1147
|
+
logger.info("Exporting PyTorch model to TensorRT...")
|
|
1148
|
+
self._remove_exported_checkpoints(checkpoint_path)
|
|
1149
|
+
checkpoint_path = self.export_tensorrt(deploy_params)
|
|
1150
|
+
deploy_params["model_files"]["checkpoint"] = checkpoint_path
|
|
1151
|
+
self.load_model(**deploy_params)
|
|
1152
|
+
else:
|
|
1153
|
+
self.load_model(**deploy_params)
|
|
1154
|
+
else:
|
|
1155
|
+
self.load_model(**deploy_params)
|
|
1156
|
+
|
|
1064
1157
|
self._model_served = True
|
|
1065
1158
|
self._deploy_params = deploy_params
|
|
1066
1159
|
if self._task_id is not None and is_production():
|
|
@@ -1269,18 +1362,19 @@ class Inference:
|
|
|
1269
1362
|
|
|
1270
1363
|
def get_classes(self) -> List[str]:
|
|
1271
1364
|
return self.classes
|
|
1272
|
-
|
|
1365
|
+
|
|
1273
1366
|
def _tracker_init(self, tracker: str, tracker_settings: dict):
|
|
1274
1367
|
# Check if tracking is supported for this model
|
|
1275
1368
|
info = self.get_info()
|
|
1276
1369
|
tracking_support = info.get("tracking_on_videos_support", False)
|
|
1277
|
-
|
|
1370
|
+
|
|
1278
1371
|
if not tracking_support:
|
|
1279
1372
|
logger.debug("Tracking is not supported for this model")
|
|
1280
1373
|
return None
|
|
1281
|
-
|
|
1374
|
+
|
|
1282
1375
|
if tracker == "botsort":
|
|
1283
1376
|
from supervisely.nn.tracker import BotSortTracker
|
|
1377
|
+
|
|
1284
1378
|
device = tracker_settings.get("device", self.device)
|
|
1285
1379
|
logger.debug(f"Initializing BotSort tracker with device: {device}")
|
|
1286
1380
|
return BotSortTracker(settings=tracker_settings, device=device)
|
|
@@ -1289,7 +1383,6 @@ class Inference:
|
|
|
1289
1383
|
logger.warning(f"Unknown tracking type: {tracker}. Tracking is disabled.")
|
|
1290
1384
|
return None
|
|
1291
1385
|
|
|
1292
|
-
|
|
1293
1386
|
def get_info(self) -> Dict[str, Any]:
|
|
1294
1387
|
num_classes = None
|
|
1295
1388
|
classes = None
|
|
@@ -1298,15 +1391,15 @@ class Inference:
|
|
|
1298
1391
|
if classes is not None:
|
|
1299
1392
|
num_classes = len(classes)
|
|
1300
1393
|
except NotImplementedError:
|
|
1301
|
-
logger.
|
|
1394
|
+
logger.warning(f"get_classes() function not implemented for {type(self)} object.")
|
|
1302
1395
|
except AttributeError:
|
|
1303
|
-
logger.
|
|
1396
|
+
logger.warning("Probably, get_classes() function not working without model deploy.")
|
|
1304
1397
|
except Exception as exc:
|
|
1305
|
-
logger.
|
|
1398
|
+
logger.warning("Unknown exception. Please, contact support")
|
|
1306
1399
|
logger.exception(exc)
|
|
1307
1400
|
|
|
1308
1401
|
if num_classes is None:
|
|
1309
|
-
logger.
|
|
1402
|
+
logger.warning(f"get_classes() function return {classes}; skip classes processing.")
|
|
1310
1403
|
|
|
1311
1404
|
return {
|
|
1312
1405
|
"app_name": get_name_from_env(default="Neural Network Serving"),
|
|
@@ -1324,6 +1417,42 @@ class Inference:
|
|
|
1324
1417
|
|
|
1325
1418
|
# pylint: enable=method-hidden
|
|
1326
1419
|
|
|
1420
|
+
def get_tracking_settings(self) -> Dict[str, Dict[str, Any]]:
|
|
1421
|
+
"""
|
|
1422
|
+
Get default parameters for all available tracking algorithms.
|
|
1423
|
+
|
|
1424
|
+
Returns:
|
|
1425
|
+
{"botsort": {"track_high_thresh": 0.6, ...}}
|
|
1426
|
+
Empty dict if tracking not supported.
|
|
1427
|
+
"""
|
|
1428
|
+
info = self.get_info()
|
|
1429
|
+
trackers_params = {}
|
|
1430
|
+
|
|
1431
|
+
tracking_support = info.get("tracking_on_videos_support")
|
|
1432
|
+
if not tracking_support:
|
|
1433
|
+
return trackers_params
|
|
1434
|
+
|
|
1435
|
+
tracking_algorithms = info.get("tracking_algorithms", [])
|
|
1436
|
+
|
|
1437
|
+
for tracker_name in tracking_algorithms:
|
|
1438
|
+
try:
|
|
1439
|
+
if tracker_name == "botsort":
|
|
1440
|
+
from supervisely.nn.tracker import BotSortTracker
|
|
1441
|
+
|
|
1442
|
+
trackers_params[tracker_name] = BotSortTracker.get_default_params()
|
|
1443
|
+
# Add other trackers here as elif blocks
|
|
1444
|
+
else:
|
|
1445
|
+
logger.debug(f"Tracker '{tracker_name}' not implemented")
|
|
1446
|
+
except Exception as e:
|
|
1447
|
+
logger.warning(f"Failed to get params for '{tracker_name}': {e}")
|
|
1448
|
+
|
|
1449
|
+
INTERNAL_FIELDS = {"device", "fps"}
|
|
1450
|
+
for tracker_name, params in trackers_params.items():
|
|
1451
|
+
trackers_params[tracker_name] = {
|
|
1452
|
+
k: v for k, v in params.items() if k not in INTERNAL_FIELDS
|
|
1453
|
+
}
|
|
1454
|
+
return trackers_params
|
|
1455
|
+
|
|
1327
1456
|
def get_human_readable_info(self, replace_none_with: Optional[str] = None):
|
|
1328
1457
|
hr_info = {}
|
|
1329
1458
|
info = self.get_info()
|
|
@@ -1439,8 +1568,12 @@ class Inference:
|
|
|
1439
1568
|
# for example empty mask
|
|
1440
1569
|
continue
|
|
1441
1570
|
if isinstance(label, list):
|
|
1571
|
+
for lb in label:
|
|
1572
|
+
lb.status = LabelingStatus.AUTO
|
|
1442
1573
|
labels.extend(label)
|
|
1443
1574
|
continue
|
|
1575
|
+
|
|
1576
|
+
label.status = LabelingStatus.AUTO
|
|
1444
1577
|
labels.append(label)
|
|
1445
1578
|
|
|
1446
1579
|
# create annotation with correct image resolution
|
|
@@ -1871,8 +2004,8 @@ class Inference:
|
|
|
1871
2004
|
else:
|
|
1872
2005
|
n_frames = frames_reader.frames_count()
|
|
1873
2006
|
|
|
1874
|
-
|
|
1875
|
-
|
|
2007
|
+
inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
|
|
2008
|
+
|
|
1876
2009
|
progress_total = (n_frames + step - 1) // step
|
|
1877
2010
|
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
|
|
1878
2011
|
|
|
@@ -1896,32 +2029,30 @@ class Inference:
|
|
|
1896
2029
|
source=frames,
|
|
1897
2030
|
settings=inference_settings,
|
|
1898
2031
|
)
|
|
1899
|
-
|
|
1900
|
-
if
|
|
1901
|
-
anns = self._apply_tracker_to_anns(frames, anns)
|
|
1902
|
-
|
|
2032
|
+
|
|
2033
|
+
if inference_request.tracker is not None:
|
|
2034
|
+
anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
|
|
2035
|
+
|
|
1903
2036
|
predictions = [
|
|
1904
2037
|
Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
|
|
1905
2038
|
for ann, frame_index in zip(anns, batch)
|
|
1906
2039
|
]
|
|
1907
|
-
|
|
2040
|
+
|
|
1908
2041
|
for pred, this_slides_data in zip(predictions, slides_data):
|
|
1909
2042
|
pred.extra_data["slides_data"] = this_slides_data
|
|
1910
2043
|
batch_results = self._format_output(predictions)
|
|
1911
|
-
|
|
2044
|
+
|
|
1912
2045
|
inference_request.add_results(batch_results)
|
|
1913
2046
|
inference_request.done(len(batch_results))
|
|
1914
2047
|
logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
|
|
1915
2048
|
video_ann_json = None
|
|
1916
|
-
if
|
|
2049
|
+
if inference_request.tracker is not None:
|
|
1917
2050
|
inference_request.set_stage("Postprocess...", 0, 1)
|
|
1918
|
-
|
|
1919
|
-
video_ann_json = self._tracker.video_annotation.to_json()
|
|
2051
|
+
video_ann_json = inference_request.tracker.video_annotation.to_json()
|
|
1920
2052
|
inference_request.done()
|
|
1921
2053
|
result = {"ann": results, "video_ann": video_ann_json}
|
|
1922
2054
|
inference_request.final_result = result.copy()
|
|
1923
2055
|
return video_ann_json
|
|
1924
|
-
|
|
1925
2056
|
|
|
1926
2057
|
def _inference_image_ids(
|
|
1927
2058
|
self,
|
|
@@ -1949,7 +2080,7 @@ class Inference:
|
|
|
1949
2080
|
upload_mode = state.get("upload_mode", None)
|
|
1950
2081
|
iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
|
|
1951
2082
|
if upload_mode == "iou_merge" and iou_merge_threshold is None:
|
|
1952
|
-
iou_merge_threshold = 0.
|
|
2083
|
+
iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD # TODO: change to 0.9
|
|
1953
2084
|
|
|
1954
2085
|
images_infos = api.image.get_info_by_id_batch(image_ids)
|
|
1955
2086
|
images_infos_dict = {im_info.id: im_info for im_info in images_infos}
|
|
@@ -1991,14 +2122,9 @@ class Inference:
|
|
|
1991
2122
|
output_dataset_id
|
|
1992
2123
|
] = output_dataset_info
|
|
1993
2124
|
|
|
1994
|
-
|
|
1995
|
-
|
|
1996
|
-
|
|
1997
|
-
dataset_image_infos[image_info.dataset_id].append(image_info)
|
|
1998
|
-
for dataset_id, ds_image_infos in dataset_image_infos.items():
|
|
1999
|
-
self.cache.run_cache_task_manually(
|
|
2000
|
-
api, [info.id for info in ds_image_infos], dataset_id=dataset_id
|
|
2001
|
-
)
|
|
2125
|
+
def download_f(item: int):
|
|
2126
|
+
self.cache.download_image(api, item)
|
|
2127
|
+
return item
|
|
2002
2128
|
|
|
2003
2129
|
_upload_predictions = partial(
|
|
2004
2130
|
self.upload_predictions,
|
|
@@ -2014,7 +2140,9 @@ class Inference:
|
|
|
2014
2140
|
)
|
|
2015
2141
|
|
|
2016
2142
|
_add_results_to_request = partial(
|
|
2017
|
-
self.add_results_to_request,
|
|
2143
|
+
self.add_results_to_request,
|
|
2144
|
+
inference_request=inference_request,
|
|
2145
|
+
progress_cb=inference_request.done,
|
|
2018
2146
|
)
|
|
2019
2147
|
|
|
2020
2148
|
if upload_mode is None:
|
|
@@ -2023,40 +2151,60 @@ class Inference:
|
|
|
2023
2151
|
upload_f = _upload_predictions
|
|
2024
2152
|
|
|
2025
2153
|
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, len(image_ids))
|
|
2154
|
+
download_workers = max(8, min(batch_size, 64))
|
|
2026
2155
|
with Uploader(upload_f, logger=logger) as uploader:
|
|
2027
|
-
|
|
2028
|
-
|
|
2029
|
-
|
|
2030
|
-
|
|
2031
|
-
|
|
2032
|
-
|
|
2033
|
-
|
|
2034
|
-
|
|
2035
|
-
)
|
|
2036
|
-
|
|
2037
|
-
|
|
2038
|
-
|
|
2039
|
-
|
|
2040
|
-
|
|
2041
|
-
|
|
2042
|
-
|
|
2156
|
+
with Downloader(download_f, max_workers=download_workers, logger=logger) as downloader:
|
|
2157
|
+
for image_id in image_ids:
|
|
2158
|
+
downloader.put(image_id)
|
|
2159
|
+
downloader.next(100)
|
|
2160
|
+
for image_ids_batch in batched(image_ids, batch_size=batch_size):
|
|
2161
|
+
if uploader.has_exception():
|
|
2162
|
+
exception = uploader.exception
|
|
2163
|
+
raise exception
|
|
2164
|
+
if inference_request.is_stopped():
|
|
2165
|
+
logger.debug(
|
|
2166
|
+
f"Cancelling inference...",
|
|
2167
|
+
extra={"inference_request_uuid": inference_request.uuid},
|
|
2168
|
+
)
|
|
2169
|
+
break
|
|
2170
|
+
if inference_request.is_paused():
|
|
2171
|
+
logger.info("Inference request is paused. Waiting...")
|
|
2172
|
+
while inference_request.is_paused():
|
|
2173
|
+
if (
|
|
2174
|
+
inference_request.paused_for()
|
|
2175
|
+
> inference_request.PAUSE_SLEEP_MAX_WAIT
|
|
2176
|
+
):
|
|
2177
|
+
logger.info(
|
|
2178
|
+
"Inference request has been paused for too long. Cancelling..."
|
|
2179
|
+
)
|
|
2180
|
+
raise RuntimeError("Inference request cancelled due to long pause.")
|
|
2181
|
+
time.sleep(inference_request.PAUSE_SLEEP_INTERVAL)
|
|
2043
2182
|
|
|
2044
|
-
|
|
2045
|
-
|
|
2046
|
-
|
|
2047
|
-
|
|
2048
|
-
|
|
2049
|
-
|
|
2050
|
-
|
|
2051
|
-
name=image_info.name,
|
|
2052
|
-
image_id=image_info.id,
|
|
2053
|
-
dataset_id=image_info.dataset_id,
|
|
2054
|
-
project_id=dataset_info.project_id,
|
|
2183
|
+
images_nps = [
|
|
2184
|
+
self.cache.download_image(api, img_id) for img_id in image_ids_batch
|
|
2185
|
+
]
|
|
2186
|
+
downloader.next(len(image_ids_batch))
|
|
2187
|
+
anns, slides_data = self._inference_auto(
|
|
2188
|
+
source=images_nps,
|
|
2189
|
+
settings=inference_settings,
|
|
2055
2190
|
)
|
|
2056
|
-
prediction.extra_data["slides_data"] = this_slides_data
|
|
2057
|
-
batch_predictions.append(prediction)
|
|
2058
2191
|
|
|
2059
|
-
|
|
2192
|
+
batch_predictions = []
|
|
2193
|
+
for image_id, ann, this_slides_data in zip(image_ids_batch, anns, slides_data):
|
|
2194
|
+
image_info: ImageInfo = images_infos_dict[image_id]
|
|
2195
|
+
dataset_info = dataset_infos_dict[image_info.dataset_id]
|
|
2196
|
+
prediction = Prediction(
|
|
2197
|
+
ann,
|
|
2198
|
+
model_meta=self.model_meta,
|
|
2199
|
+
name=image_info.name,
|
|
2200
|
+
image_id=image_info.id,
|
|
2201
|
+
dataset_id=image_info.dataset_id,
|
|
2202
|
+
project_id=dataset_info.project_id,
|
|
2203
|
+
)
|
|
2204
|
+
prediction.extra_data["slides_data"] = this_slides_data
|
|
2205
|
+
batch_predictions.append(prediction)
|
|
2206
|
+
|
|
2207
|
+
uploader.put(batch_predictions)
|
|
2060
2208
|
|
|
2061
2209
|
def _inference_video_id(
|
|
2062
2210
|
self,
|
|
@@ -2071,7 +2219,7 @@ class Inference:
|
|
|
2071
2219
|
video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
|
|
2072
2220
|
if video_id is None:
|
|
2073
2221
|
raise ValueError("Video id is not provided")
|
|
2074
|
-
video_info = api.video.get_info_by_id(video_id)
|
|
2222
|
+
video_info = api.video.get_info_by_id(video_id, force_metadata_for_links=True)
|
|
2075
2223
|
start_frame_index = get_value_for_keys(
|
|
2076
2224
|
state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
|
|
2077
2225
|
)
|
|
@@ -2101,8 +2249,8 @@ class Inference:
|
|
|
2101
2249
|
else:
|
|
2102
2250
|
n_frames = video_info.frames_count
|
|
2103
2251
|
|
|
2104
|
-
|
|
2105
|
-
|
|
2252
|
+
inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
|
|
2253
|
+
|
|
2106
2254
|
logger.debug(
|
|
2107
2255
|
f"Video info:",
|
|
2108
2256
|
extra=dict(
|
|
@@ -2137,10 +2285,10 @@ class Inference:
|
|
|
2137
2285
|
source=frames,
|
|
2138
2286
|
settings=inference_settings,
|
|
2139
2287
|
)
|
|
2140
|
-
|
|
2141
|
-
if
|
|
2142
|
-
anns = self._apply_tracker_to_anns(frames, anns)
|
|
2143
|
-
|
|
2288
|
+
|
|
2289
|
+
if inference_request.tracker is not None:
|
|
2290
|
+
anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
|
|
2291
|
+
|
|
2144
2292
|
predictions = [
|
|
2145
2293
|
Prediction(
|
|
2146
2294
|
ann,
|
|
@@ -2148,21 +2296,181 @@ class Inference:
|
|
|
2148
2296
|
frame_index=frame_index,
|
|
2149
2297
|
video_id=video_info.id,
|
|
2150
2298
|
dataset_id=video_info.dataset_id,
|
|
2151
|
-
|
|
2152
|
-
|
|
2299
|
+
project_id=video_info.project_id,
|
|
2300
|
+
)
|
|
2153
2301
|
for ann, frame_index in zip(anns, batch)
|
|
2154
2302
|
]
|
|
2155
2303
|
for pred, this_slides_data in zip(predictions, slides_data):
|
|
2156
2304
|
pred.extra_data["slides_data"] = this_slides_data
|
|
2157
2305
|
batch_results = self._format_output(predictions)
|
|
2158
|
-
|
|
2306
|
+
|
|
2159
2307
|
inference_request.add_results(batch_results)
|
|
2160
2308
|
inference_request.done(len(batch_results))
|
|
2161
2309
|
logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
|
|
2162
2310
|
video_ann_json = None
|
|
2163
|
-
if
|
|
2311
|
+
if inference_request.tracker is not None:
|
|
2312
|
+
inference_request.set_stage("Postprocess...", 0, progress_total)
|
|
2313
|
+
|
|
2314
|
+
video_ann_json = inference_request.tracker.create_video_annotation(
|
|
2315
|
+
video_info.frames_count,
|
|
2316
|
+
start_frame_index,
|
|
2317
|
+
step=step,
|
|
2318
|
+
progress_cb=inference_request.done,
|
|
2319
|
+
).to_json()
|
|
2320
|
+
inference_request.final_result = {"video_ann": video_ann_json}
|
|
2321
|
+
return video_ann_json
|
|
2322
|
+
|
|
2323
|
+
def _tracking_by_detection(self, api: Api, state: dict, inference_request: InferenceRequest):
|
|
2324
|
+
logger.debug("Inferring video_id...", extra={"state": state})
|
|
2325
|
+
inference_settings = self._get_inference_settings(state)
|
|
2326
|
+
logger.debug(f"Inference settings:", extra=inference_settings)
|
|
2327
|
+
batch_size = self._get_batch_size_from_state(state)
|
|
2328
|
+
video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
|
|
2329
|
+
if video_id is None:
|
|
2330
|
+
raise ValueError("Video id is not provided")
|
|
2331
|
+
video_info = api.video.get_info_by_id(video_id)
|
|
2332
|
+
start_frame_index = get_value_for_keys(
|
|
2333
|
+
state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
|
|
2334
|
+
)
|
|
2335
|
+
if start_frame_index is None:
|
|
2336
|
+
start_frame_index = 0
|
|
2337
|
+
step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
|
|
2338
|
+
if step is None:
|
|
2339
|
+
step = 1
|
|
2340
|
+
end_frame_index = get_value_for_keys(
|
|
2341
|
+
state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
|
|
2342
|
+
)
|
|
2343
|
+
duration = state.get("duration", None)
|
|
2344
|
+
frames_count = get_value_for_keys(
|
|
2345
|
+
state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
|
|
2346
|
+
)
|
|
2347
|
+
tracking = state.get("tracker", None)
|
|
2348
|
+
direction = state.get("direction", "forward")
|
|
2349
|
+
direction = 1 if direction == "forward" else -1
|
|
2350
|
+
track_id = get_value_for_keys(state, ["trackId", "track_id"], ignore_none=True)
|
|
2351
|
+
|
|
2352
|
+
if frames_count is not None:
|
|
2353
|
+
n_frames = frames_count
|
|
2354
|
+
elif end_frame_index is not None:
|
|
2355
|
+
n_frames = end_frame_index - start_frame_index
|
|
2356
|
+
elif duration is not None:
|
|
2357
|
+
fps = video_info.frames_count / video_info.duration
|
|
2358
|
+
n_frames = int(duration * fps)
|
|
2359
|
+
else:
|
|
2360
|
+
n_frames = video_info.frames_count
|
|
2361
|
+
|
|
2362
|
+
inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
|
|
2363
|
+
|
|
2364
|
+
logger.debug(
|
|
2365
|
+
f"Video info:",
|
|
2366
|
+
extra=dict(
|
|
2367
|
+
w=video_info.frame_width,
|
|
2368
|
+
h=video_info.frame_height,
|
|
2369
|
+
start_frame_index=start_frame_index,
|
|
2370
|
+
n_frames=n_frames,
|
|
2371
|
+
),
|
|
2372
|
+
)
|
|
2373
|
+
|
|
2374
|
+
# start downloading video in background
|
|
2375
|
+
self.cache.run_cache_task_manually(api, None, video_id=video_id)
|
|
2376
|
+
|
|
2377
|
+
progress_total = (n_frames + step - 1) // step
|
|
2378
|
+
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
|
|
2379
|
+
|
|
2380
|
+
_upload_f = partial(
|
|
2381
|
+
self.upload_predictions_to_video,
|
|
2382
|
+
api=api,
|
|
2383
|
+
video_info=video_info,
|
|
2384
|
+
track_id=track_id,
|
|
2385
|
+
context=inference_request.context,
|
|
2386
|
+
progress_cb=inference_request.done,
|
|
2387
|
+
inference_request=inference_request,
|
|
2388
|
+
)
|
|
2389
|
+
|
|
2390
|
+
_range = (start_frame_index, start_frame_index + direction * n_frames)
|
|
2391
|
+
if _range[0] > _range[1]:
|
|
2392
|
+
_range = (_range[1], _range[0])
|
|
2393
|
+
|
|
2394
|
+
def _notify_f(predictions: List[Prediction]):
|
|
2395
|
+
logger.debug(
|
|
2396
|
+
"Notifying tracking progress...",
|
|
2397
|
+
extra={
|
|
2398
|
+
"track_id": track_id,
|
|
2399
|
+
"range": _range,
|
|
2400
|
+
"current": inference_request.progress.current,
|
|
2401
|
+
"total": inference_request.progress.total,
|
|
2402
|
+
},
|
|
2403
|
+
)
|
|
2404
|
+
stopped = self.api.video.notify_progress(
|
|
2405
|
+
track_id=track_id,
|
|
2406
|
+
video_id=video_info.id,
|
|
2407
|
+
frame_start=_range[0],
|
|
2408
|
+
frame_end=_range[1],
|
|
2409
|
+
current=inference_request.progress.current,
|
|
2410
|
+
total=inference_request.progress.total,
|
|
2411
|
+
)
|
|
2412
|
+
if stopped:
|
|
2413
|
+
inference_request.stop()
|
|
2414
|
+
logger.info("Tracking has been stopped by user", extra={"track_id": track_id})
|
|
2415
|
+
|
|
2416
|
+
def _exception_handler(e: Exception):
|
|
2417
|
+
self.api.video.notify_tracking_error(
|
|
2418
|
+
track_id=track_id,
|
|
2419
|
+
error=str(type(e)),
|
|
2420
|
+
message=str(e),
|
|
2421
|
+
)
|
|
2422
|
+
raise e
|
|
2423
|
+
|
|
2424
|
+
with Uploader(
|
|
2425
|
+
upload_f=_upload_f,
|
|
2426
|
+
notify_f=_notify_f,
|
|
2427
|
+
exception_handler=_exception_handler,
|
|
2428
|
+
logger=logger,
|
|
2429
|
+
) as uploader:
|
|
2430
|
+
for batch in batched(
|
|
2431
|
+
range(
|
|
2432
|
+
start_frame_index, start_frame_index + direction * n_frames, direction * step
|
|
2433
|
+
),
|
|
2434
|
+
batch_size,
|
|
2435
|
+
):
|
|
2436
|
+
if inference_request.is_stopped():
|
|
2437
|
+
logger.debug(
|
|
2438
|
+
f"Cancelling inference video...",
|
|
2439
|
+
extra={"inference_request_uuid": inference_request.uuid},
|
|
2440
|
+
)
|
|
2441
|
+
break
|
|
2442
|
+
logger.debug(
|
|
2443
|
+
f"Inferring frames {batch[0]}-{batch[-1]}:",
|
|
2444
|
+
)
|
|
2445
|
+
frames = self.cache.download_frames(
|
|
2446
|
+
api, video_info.id, batch, redownload_video=True
|
|
2447
|
+
)
|
|
2448
|
+
anns, slides_data = self._inference_auto(
|
|
2449
|
+
source=frames,
|
|
2450
|
+
settings=inference_settings,
|
|
2451
|
+
)
|
|
2452
|
+
|
|
2453
|
+
if inference_request.tracker is not None:
|
|
2454
|
+
anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
|
|
2455
|
+
|
|
2456
|
+
predictions = [
|
|
2457
|
+
Prediction(
|
|
2458
|
+
ann,
|
|
2459
|
+
model_meta=self.model_meta,
|
|
2460
|
+
frame_index=frame_index,
|
|
2461
|
+
video_id=video_info.id,
|
|
2462
|
+
dataset_id=video_info.dataset_id,
|
|
2463
|
+
project_id=video_info.project_id,
|
|
2464
|
+
)
|
|
2465
|
+
for ann, frame_index in zip(anns, batch)
|
|
2466
|
+
]
|
|
2467
|
+
for pred, this_slides_data in zip(predictions, slides_data):
|
|
2468
|
+
pred.extra_data["slides_data"] = this_slides_data
|
|
2469
|
+
uploader.put(predictions)
|
|
2470
|
+
video_ann_json = None
|
|
2471
|
+
if inference_request.tracker is not None:
|
|
2164
2472
|
inference_request.set_stage("Postprocess...", 0, 1)
|
|
2165
|
-
video_ann_json =
|
|
2473
|
+
video_ann_json = inference_request.tracker.video_annotation.to_json()
|
|
2166
2474
|
inference_request.done()
|
|
2167
2475
|
inference_request.final_result = {"video_ann": video_ann_json}
|
|
2168
2476
|
return video_ann_json
|
|
@@ -2188,10 +2496,9 @@ class Inference:
|
|
|
2188
2496
|
upload_mode = state.get("upload_mode", None)
|
|
2189
2497
|
iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
|
|
2190
2498
|
if upload_mode == "iou_merge" and iou_merge_threshold is None:
|
|
2191
|
-
iou_merge_threshold =
|
|
2499
|
+
iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD
|
|
2192
2500
|
cache_project_on_model = state.get("cache_project_on_model", False)
|
|
2193
2501
|
|
|
2194
|
-
project_info = api.project.get_info_by_id(project_id)
|
|
2195
2502
|
inference_request.context.setdefault("project_info", {})[project_id] = project_info
|
|
2196
2503
|
dataset_ids = state.get("dataset_ids", None)
|
|
2197
2504
|
if dataset_ids is None:
|
|
@@ -2226,7 +2533,11 @@ class Inference:
|
|
|
2226
2533
|
|
|
2227
2534
|
if cache_project_on_model:
|
|
2228
2535
|
download_to_cache(
|
|
2229
|
-
api,
|
|
2536
|
+
api,
|
|
2537
|
+
project_info.id,
|
|
2538
|
+
datasets_infos,
|
|
2539
|
+
progress_cb=inference_request.done,
|
|
2540
|
+
skip_create_readme=True,
|
|
2230
2541
|
)
|
|
2231
2542
|
|
|
2232
2543
|
images_infos_dict = {}
|
|
@@ -2235,20 +2546,9 @@ class Inference:
|
|
|
2235
2546
|
if not cache_project_on_model:
|
|
2236
2547
|
inference_request.done(dataset_info.items_count)
|
|
2237
2548
|
|
|
2238
|
-
def
|
|
2239
|
-
|
|
2240
|
-
|
|
2241
|
-
with ThreadPoolExecutor(max(8, min(batch_size, 64))) as executor:
|
|
2242
|
-
for image_id in image_ids:
|
|
2243
|
-
executor.submit(
|
|
2244
|
-
self.cache.download_image,
|
|
2245
|
-
api,
|
|
2246
|
-
image_id,
|
|
2247
|
-
)
|
|
2248
|
-
|
|
2249
|
-
if not cache_project_on_model:
|
|
2250
|
-
# start downloading in parallel
|
|
2251
|
-
threading.Thread(target=_download_images, args=[datasets_infos], daemon=True).start()
|
|
2549
|
+
def download_f(item: int):
|
|
2550
|
+
self.cache.download_image(api, item)
|
|
2551
|
+
return item
|
|
2252
2552
|
|
|
2253
2553
|
_upload_predictions = partial(
|
|
2254
2554
|
self.upload_predictions,
|
|
@@ -2263,7 +2563,9 @@ class Inference:
|
|
|
2263
2563
|
)
|
|
2264
2564
|
|
|
2265
2565
|
_add_results_to_request = partial(
|
|
2266
|
-
self.add_results_to_request,
|
|
2566
|
+
self.add_results_to_request,
|
|
2567
|
+
inference_request=inference_request,
|
|
2568
|
+
progress_cb=inference_request.done,
|
|
2267
2569
|
)
|
|
2268
2570
|
|
|
2269
2571
|
if upload_mode is None:
|
|
@@ -2271,57 +2573,78 @@ class Inference:
|
|
|
2271
2573
|
else:
|
|
2272
2574
|
upload_f = _upload_predictions
|
|
2273
2575
|
|
|
2576
|
+
download_workers = max(8, min(batch_size, 64))
|
|
2274
2577
|
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, inference_progress_total)
|
|
2275
2578
|
with Uploader(upload_f, logger=logger) as uploader:
|
|
2276
|
-
|
|
2277
|
-
for
|
|
2278
|
-
|
|
2279
|
-
|
|
2280
|
-
|
|
2281
|
-
|
|
2282
|
-
|
|
2283
|
-
|
|
2284
|
-
|
|
2285
|
-
|
|
2286
|
-
|
|
2287
|
-
|
|
2288
|
-
|
|
2289
|
-
|
|
2290
|
-
|
|
2291
|
-
|
|
2292
|
-
project_info.id,
|
|
2293
|
-
dataset_info.name,
|
|
2294
|
-
[ii.name for ii in images_infos_batch],
|
|
2579
|
+
with Downloader(download_f, max_workers=download_workers, logger=logger) as downloader:
|
|
2580
|
+
for images in images_infos_dict.values():
|
|
2581
|
+
for image in images:
|
|
2582
|
+
downloader.put(image.id)
|
|
2583
|
+
downloader.next(100)
|
|
2584
|
+
for dataset_info in datasets_infos:
|
|
2585
|
+
for images_infos_batch in batched(
|
|
2586
|
+
images_infos_dict[dataset_info.id], batch_size=batch_size
|
|
2587
|
+
):
|
|
2588
|
+
if uploader.has_exception():
|
|
2589
|
+
exception = uploader.exception
|
|
2590
|
+
raise exception
|
|
2591
|
+
if inference_request.is_stopped():
|
|
2592
|
+
logger.debug(
|
|
2593
|
+
f"Cancelling inference project...",
|
|
2594
|
+
extra={"inference_request_uuid": inference_request.uuid},
|
|
2295
2595
|
)
|
|
2596
|
+
return
|
|
2597
|
+
if inference_request.is_paused():
|
|
2598
|
+
logger.info("Inference request is paused. Waiting...")
|
|
2599
|
+
while inference_request.is_paused():
|
|
2600
|
+
if (
|
|
2601
|
+
inference_request.paused_for()
|
|
2602
|
+
> inference_request.PAUSE_SLEEP_MAX_WAIT
|
|
2603
|
+
):
|
|
2604
|
+
logger.info(
|
|
2605
|
+
"Inference request has been paused for too long. Cancelling..."
|
|
2606
|
+
)
|
|
2607
|
+
raise RuntimeError(
|
|
2608
|
+
"Inference request cancelled due to long pause."
|
|
2609
|
+
)
|
|
2610
|
+
time.sleep(inference_request.PAUSE_SLEEP_INTERVAL)
|
|
2611
|
+
if cache_project_on_model:
|
|
2612
|
+
images_paths, _ = zip(
|
|
2613
|
+
*read_from_cached_project(
|
|
2614
|
+
project_info.id,
|
|
2615
|
+
dataset_info.name,
|
|
2616
|
+
[ii.name for ii in images_infos_batch],
|
|
2617
|
+
)
|
|
2618
|
+
)
|
|
2619
|
+
images_nps = [sly_image.read(img_path) for img_path in images_paths]
|
|
2620
|
+
else:
|
|
2621
|
+
images_nps = self.cache.download_images(
|
|
2622
|
+
api,
|
|
2623
|
+
dataset_info.id,
|
|
2624
|
+
[info.id for info in images_infos_batch],
|
|
2625
|
+
return_images=True,
|
|
2626
|
+
)
|
|
2627
|
+
downloader.next(len(images_infos_batch))
|
|
2628
|
+
anns, slides_data = self._inference_auto(
|
|
2629
|
+
source=images_nps,
|
|
2630
|
+
settings=inference_settings,
|
|
2296
2631
|
)
|
|
2297
|
-
|
|
2298
|
-
|
|
2299
|
-
|
|
2300
|
-
|
|
2301
|
-
|
|
2302
|
-
|
|
2303
|
-
|
|
2304
|
-
|
|
2305
|
-
|
|
2306
|
-
|
|
2307
|
-
|
|
2308
|
-
|
|
2309
|
-
|
|
2310
|
-
|
|
2311
|
-
ann,
|
|
2312
|
-
model_meta=self.model_meta,
|
|
2313
|
-
image_id=image_info.id,
|
|
2314
|
-
name=image_info.name,
|
|
2315
|
-
dataset_id=dataset_info.id,
|
|
2316
|
-
project_id=dataset_info.project_id,
|
|
2317
|
-
image_name=image_info.name,
|
|
2318
|
-
)
|
|
2319
|
-
for ann, image_info in zip(anns, images_infos_batch)
|
|
2320
|
-
]
|
|
2321
|
-
for pred, this_slides_data in zip(predictions, slides_data):
|
|
2322
|
-
pred.extra_data["slides_data"] = this_slides_data
|
|
2632
|
+
predictions = [
|
|
2633
|
+
Prediction(
|
|
2634
|
+
ann,
|
|
2635
|
+
model_meta=self.model_meta,
|
|
2636
|
+
image_id=image_info.id,
|
|
2637
|
+
name=image_info.name,
|
|
2638
|
+
dataset_id=dataset_info.id,
|
|
2639
|
+
project_id=dataset_info.project_id,
|
|
2640
|
+
image_name=image_info.name,
|
|
2641
|
+
)
|
|
2642
|
+
for ann, image_info in zip(anns, images_infos_batch)
|
|
2643
|
+
]
|
|
2644
|
+
for pred, this_slides_data in zip(predictions, slides_data):
|
|
2645
|
+
pred.extra_data["slides_data"] = this_slides_data
|
|
2323
2646
|
|
|
2324
|
-
|
|
2647
|
+
uploader.put(predictions)
|
|
2325
2648
|
|
|
2326
2649
|
def _run_speedtest(
|
|
2327
2650
|
self,
|
|
@@ -2364,7 +2687,13 @@ class Inference:
|
|
|
2364
2687
|
inference_request.done()
|
|
2365
2688
|
|
|
2366
2689
|
if cache_project_on_model:
|
|
2367
|
-
download_to_cache(
|
|
2690
|
+
download_to_cache(
|
|
2691
|
+
api,
|
|
2692
|
+
project_id,
|
|
2693
|
+
datasets_infos,
|
|
2694
|
+
progress_cb=inference_request.done,
|
|
2695
|
+
skip_create_readme=True,
|
|
2696
|
+
)
|
|
2368
2697
|
|
|
2369
2698
|
inference_request.set_stage("warmup", 0, num_warmup)
|
|
2370
2699
|
|
|
@@ -2485,6 +2814,11 @@ class Inference:
|
|
|
2485
2814
|
def _freeze_model(self):
|
|
2486
2815
|
if self._model_frozen or not self._model_served:
|
|
2487
2816
|
return
|
|
2817
|
+
|
|
2818
|
+
if not self._deploy_params:
|
|
2819
|
+
logger.warning("Deploy params are not set, cannot freeze the model.")
|
|
2820
|
+
return
|
|
2821
|
+
|
|
2488
2822
|
logger.debug("Freezing model...")
|
|
2489
2823
|
runtime = self._deploy_params.get("runtime")
|
|
2490
2824
|
if runtime and runtime.lower() != RuntimeType.PYTORCH.lower():
|
|
@@ -2524,7 +2858,6 @@ class Inference:
|
|
|
2524
2858
|
timer.daemon = True
|
|
2525
2859
|
timer.start()
|
|
2526
2860
|
self._freeze_timer = timer
|
|
2527
|
-
logger.debug("Model will be frozen in %s seconds due to inactivity.", self._inactivity_timeout)
|
|
2528
2861
|
|
|
2529
2862
|
def _set_served_callback(self):
|
|
2530
2863
|
self._model_served = True
|
|
@@ -2637,6 +2970,10 @@ class Inference:
|
|
|
2637
2970
|
for prediction in predictions:
|
|
2638
2971
|
ds_predictions[prediction.dataset_id].append(prediction)
|
|
2639
2972
|
|
|
2973
|
+
def update_labeling_status(ann: Annotation) -> Annotation:
|
|
2974
|
+
for label in ann.labels:
|
|
2975
|
+
label.status = LabelingStatus.AUTO
|
|
2976
|
+
|
|
2640
2977
|
def _new_name(image_info: ImageInfo):
|
|
2641
2978
|
name = Path(image_info.name)
|
|
2642
2979
|
stem = name.stem
|
|
@@ -2669,10 +3006,10 @@ class Inference:
|
|
|
2669
3006
|
context.setdefault("created_dataset", {})[src_dataset_id] = created_dataset.id
|
|
2670
3007
|
return created_dataset.id
|
|
2671
3008
|
|
|
2672
|
-
created_names = []
|
|
2673
3009
|
if context is None:
|
|
2674
3010
|
context = {}
|
|
2675
3011
|
for dataset_id, preds in ds_predictions.items():
|
|
3012
|
+
created_names = set()
|
|
2676
3013
|
if dst_project_id is not None:
|
|
2677
3014
|
# upload to the destination project
|
|
2678
3015
|
dst_dataset_id = _get_or_create_dataset(
|
|
@@ -2712,8 +3049,15 @@ class Inference:
|
|
|
2712
3049
|
iou=iou_merge_threshold,
|
|
2713
3050
|
meta=project_meta,
|
|
2714
3051
|
)
|
|
3052
|
+
|
|
3053
|
+
# Update labeling status of new predictions before upload
|
|
3054
|
+
anns_with_nn_flags = []
|
|
2715
3055
|
for pred, ann in zip(preds, anns):
|
|
3056
|
+
update_labeling_status(ann)
|
|
2716
3057
|
pred.annotation = ann
|
|
3058
|
+
anns_with_nn_flags.append(ann)
|
|
3059
|
+
|
|
3060
|
+
anns = anns_with_nn_flags
|
|
2717
3061
|
|
|
2718
3062
|
context.setdefault("image_info", {})
|
|
2719
3063
|
missing = [
|
|
@@ -2741,7 +3085,7 @@ class Inference:
|
|
|
2741
3085
|
with_annotations=False,
|
|
2742
3086
|
save_source_date=False,
|
|
2743
3087
|
)
|
|
2744
|
-
created_names.
|
|
3088
|
+
created_names.update([image_info.name for image_info in dst_image_infos])
|
|
2745
3089
|
api.annotation.upload_anns([image_info.id for image_info in dst_image_infos], anns)
|
|
2746
3090
|
else:
|
|
2747
3091
|
# upload to the source dataset
|
|
@@ -2778,7 +3122,10 @@ class Inference:
|
|
|
2778
3122
|
iou=iou_merge_threshold,
|
|
2779
3123
|
meta=project_meta,
|
|
2780
3124
|
)
|
|
3125
|
+
|
|
3126
|
+
# Update labeling status of predicted labels before optional merge
|
|
2781
3127
|
for pred, ann in zip(preds, anns):
|
|
3128
|
+
update_labeling_status(ann)
|
|
2782
3129
|
pred.annotation = ann
|
|
2783
3130
|
|
|
2784
3131
|
if upload_mode in ["iou_merge", "append"]:
|
|
@@ -2814,11 +3161,89 @@ class Inference:
|
|
|
2814
3161
|
inference_request.add_results(results)
|
|
2815
3162
|
|
|
2816
3163
|
def add_results_to_request(
|
|
2817
|
-
self, predictions: List[Prediction], inference_request: InferenceRequest
|
|
3164
|
+
self, predictions: List[Prediction], inference_request: InferenceRequest, progress_cb=None
|
|
2818
3165
|
):
|
|
2819
3166
|
results = self._format_output(predictions)
|
|
2820
3167
|
inference_request.add_results(results)
|
|
2821
|
-
|
|
3168
|
+
if progress_cb:
|
|
3169
|
+
progress_cb(len(results))
|
|
3170
|
+
|
|
3171
|
+
def upload_predictions_to_video(
|
|
3172
|
+
self,
|
|
3173
|
+
predictions: List[Prediction],
|
|
3174
|
+
api: Api,
|
|
3175
|
+
video_info: VideoInfo,
|
|
3176
|
+
track_id: str,
|
|
3177
|
+
context: Dict,
|
|
3178
|
+
progress_cb=None,
|
|
3179
|
+
inference_request: InferenceRequest = None,
|
|
3180
|
+
):
|
|
3181
|
+
key_id_map = KeyIdMap()
|
|
3182
|
+
project_meta = context.get("project_meta", None)
|
|
3183
|
+
if project_meta is None:
|
|
3184
|
+
project_meta = ProjectMeta.from_json(api.project.get_meta(video_info.project_id))
|
|
3185
|
+
context["project_meta"] = project_meta
|
|
3186
|
+
meta_changed = False
|
|
3187
|
+
for prediction in predictions:
|
|
3188
|
+
project_meta, ann, meta_changed_ = update_meta_and_ann(
|
|
3189
|
+
project_meta, prediction.annotation, None
|
|
3190
|
+
)
|
|
3191
|
+
prediction.annotation = ann
|
|
3192
|
+
meta_changed = meta_changed or meta_changed_
|
|
3193
|
+
if meta_changed:
|
|
3194
|
+
project_meta = api.project.update_meta(video_info.project_id, project_meta)
|
|
3195
|
+
context["project_meta"] = project_meta
|
|
3196
|
+
|
|
3197
|
+
figure_data_by_object_id = defaultdict(list)
|
|
3198
|
+
|
|
3199
|
+
tracks_to_object_ids = context.setdefault("tracks_to_object_ids", {})
|
|
3200
|
+
new_tracks: Dict[int, VideoObject] = {}
|
|
3201
|
+
for prediction in predictions:
|
|
3202
|
+
annotation = prediction.annotation
|
|
3203
|
+
tracks = annotation.custom_data
|
|
3204
|
+
for track, label in zip(tracks, annotation.labels):
|
|
3205
|
+
if track not in tracks_to_object_ids and track not in new_tracks:
|
|
3206
|
+
video_object = VideoObject(obj_class=label.obj_class)
|
|
3207
|
+
new_tracks[track] = video_object
|
|
3208
|
+
if new_tracks:
|
|
3209
|
+
tracks, video_objects = zip(*new_tracks.items())
|
|
3210
|
+
added_object_ids = api.video.object.append_bulk(
|
|
3211
|
+
video_info.id, VideoObjectCollection(video_objects), key_id_map=key_id_map
|
|
3212
|
+
)
|
|
3213
|
+
for track, object_id in zip(tracks, added_object_ids):
|
|
3214
|
+
tracks_to_object_ids[track] = object_id
|
|
3215
|
+
for prediction in predictions:
|
|
3216
|
+
annotation = prediction.annotation
|
|
3217
|
+
tracks = annotation.custom_data
|
|
3218
|
+
for track, label in zip(tracks, annotation.labels):
|
|
3219
|
+
object_id = tracks_to_object_ids[track]
|
|
3220
|
+
figure_data_by_object_id[object_id].append(
|
|
3221
|
+
{
|
|
3222
|
+
ApiField.OBJECT_ID: object_id,
|
|
3223
|
+
ApiField.GEOMETRY_TYPE: label.geometry.geometry_name(),
|
|
3224
|
+
ApiField.GEOMETRY: label.geometry.to_json(),
|
|
3225
|
+
ApiField.META: {ApiField.FRAME: prediction.frame_index},
|
|
3226
|
+
ApiField.TRACK_ID: track_id,
|
|
3227
|
+
}
|
|
3228
|
+
)
|
|
3229
|
+
|
|
3230
|
+
for object_id, figures_data in figure_data_by_object_id.items():
|
|
3231
|
+
figures_keys = [uuid.uuid4() for _ in figures_data]
|
|
3232
|
+
api.video.figure._append_bulk(
|
|
3233
|
+
entity_id=video_info.id,
|
|
3234
|
+
figures_json=figures_data,
|
|
3235
|
+
figures_keys=figures_keys,
|
|
3236
|
+
key_id_map=key_id_map,
|
|
3237
|
+
)
|
|
3238
|
+
logger.debug(f"Added {len(figures_data)} geometries to object #{object_id}")
|
|
3239
|
+
if progress_cb:
|
|
3240
|
+
progress_cb(len(predictions))
|
|
3241
|
+
if inference_request is not None:
|
|
3242
|
+
results = self._format_output(predictions)
|
|
3243
|
+
for result in results:
|
|
3244
|
+
result["annotation"] = None
|
|
3245
|
+
result["data"] = None
|
|
3246
|
+
inference_request.add_results(results)
|
|
2822
3247
|
|
|
2823
3248
|
def serve(self):
|
|
2824
3249
|
if not self._use_gui and not self._is_cli_deploy:
|
|
@@ -2902,7 +3327,7 @@ class Inference:
|
|
|
2902
3327
|
|
|
2903
3328
|
if not self._use_gui:
|
|
2904
3329
|
Progress("Model deployed", 1).iter_done_report()
|
|
2905
|
-
|
|
3330
|
+
elif self.api is not None:
|
|
2906
3331
|
autostart_func()
|
|
2907
3332
|
|
|
2908
3333
|
@server.exception_handler(HTTPException)
|
|
@@ -2929,6 +3354,11 @@ class Inference:
|
|
|
2929
3354
|
def get_session_info(response: Response):
|
|
2930
3355
|
return self.get_info()
|
|
2931
3356
|
|
|
3357
|
+
@server.post("/get_tracking_settings")
|
|
3358
|
+
@self._check_serve_before_call
|
|
3359
|
+
def get_tracking_settings(response: Response):
|
|
3360
|
+
return self.get_tracking_settings()
|
|
3361
|
+
|
|
2932
3362
|
@server.post("/get_custom_inference_settings")
|
|
2933
3363
|
def get_custom_inference_settings():
|
|
2934
3364
|
return {"settings": self.custom_inference_settings}
|
|
@@ -3212,6 +3642,22 @@ class Inference:
|
|
|
3212
3642
|
"inference_request_uuid": inference_request.uuid,
|
|
3213
3643
|
}
|
|
3214
3644
|
|
|
3645
|
+
@server.post("/tracking_by_detection")
|
|
3646
|
+
def tracking_by_detection(response: Response, request: Request):
|
|
3647
|
+
state = request.state.state
|
|
3648
|
+
context = request.state.context
|
|
3649
|
+
state.update(context)
|
|
3650
|
+
if state.get("tracker") is None:
|
|
3651
|
+
state["tracker"] = "botsort"
|
|
3652
|
+
|
|
3653
|
+
logger.debug("Received a request to 'tracking_by_detection'", extra={"state": state})
|
|
3654
|
+
self.validate_inference_state(state)
|
|
3655
|
+
api = self.api_from_request(request)
|
|
3656
|
+
inference_request, future = self.inference_requests_manager.schedule_task(
|
|
3657
|
+
self._tracking_by_detection, api, state
|
|
3658
|
+
)
|
|
3659
|
+
return {"message": "Track task started."}
|
|
3660
|
+
|
|
3215
3661
|
@server.post("/inference_project_id_async")
|
|
3216
3662
|
def inference_project_id_async(response: Response, request: Request):
|
|
3217
3663
|
state = request.state.state
|
|
@@ -3275,10 +3721,7 @@ class Inference:
|
|
|
3275
3721
|
data = {**inference_request.to_json(), **log_extra}
|
|
3276
3722
|
if inference_request.stage != InferenceRequest.Stage.INFERENCE:
|
|
3277
3723
|
data["progress"] = {"current": 0, "total": 1}
|
|
3278
|
-
logger.debug(
|
|
3279
|
-
f"Sending inference progress with uuid:",
|
|
3280
|
-
extra=data,
|
|
3281
|
-
)
|
|
3724
|
+
logger.debug(f"Sending inference progress with uuid:", extra=data)
|
|
3282
3725
|
return data
|
|
3283
3726
|
|
|
3284
3727
|
@server.post(f"/pop_inference_results")
|
|
@@ -4135,20 +4578,20 @@ class Inference:
|
|
|
4135
4578
|
self._args.draw,
|
|
4136
4579
|
)
|
|
4137
4580
|
|
|
4138
|
-
def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation]):
|
|
4581
|
+
def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation], tracker):
|
|
4139
4582
|
updated_anns = []
|
|
4140
4583
|
for frame, ann in zip(frames, anns):
|
|
4141
|
-
matches =
|
|
4584
|
+
matches = tracker.update(frame, ann)
|
|
4142
4585
|
track_ids = [match["track_id"] for match in matches]
|
|
4143
4586
|
tracked_labels = [match["label"] for match in matches]
|
|
4144
|
-
|
|
4587
|
+
|
|
4145
4588
|
filtered_annotation = ann.clone(
|
|
4146
4589
|
labels=tracked_labels,
|
|
4147
4590
|
custom_data=track_ids
|
|
4148
4591
|
)
|
|
4149
4592
|
updated_anns.append(filtered_annotation)
|
|
4150
4593
|
return updated_anns
|
|
4151
|
-
|
|
4594
|
+
|
|
4152
4595
|
def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
|
|
4153
4596
|
if model_source == ModelSource.PRETRAINED:
|
|
4154
4597
|
checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
|
|
@@ -4198,62 +4641,78 @@ class Inference:
|
|
|
4198
4641
|
return
|
|
4199
4642
|
self.gui.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)
|
|
4200
4643
|
|
|
4644
|
+
def export_onnx(self, deploy_params: dict):
|
|
4645
|
+
raise NotImplementedError("Have to be implemented in child class after inheritance")
|
|
4201
4646
|
|
|
4202
|
-
def
|
|
4203
|
-
|
|
4204
|
-
|
|
4205
|
-
|
|
4206
|
-
|
|
4207
|
-
|
|
4208
|
-
meta: Optional[ProjectMeta] = None,
|
|
4647
|
+
def export_tensorrt(self, deploy_params: dict):
|
|
4648
|
+
raise NotImplementedError("Have to be implemented in child class after inheritance")
|
|
4649
|
+
|
|
4650
|
+
|
|
4651
|
+
def _filter_duplicated_predictions_from_ann_cpu(
|
|
4652
|
+
gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
|
|
4209
4653
|
):
|
|
4210
4654
|
"""
|
|
4211
|
-
Filter out
|
|
4655
|
+
Filter out predicted labels whose bboxes have IoU > iou_threshold with any GT label.
|
|
4656
|
+
Uses Shapely for geometric operations.
|
|
4212
4657
|
|
|
4213
|
-
|
|
4214
|
-
|
|
4215
|
-
|
|
4216
|
-
|
|
4217
|
-
- Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
|
|
4658
|
+
Args:
|
|
4659
|
+
pred_ann: Predicted annotation object
|
|
4660
|
+
gt_ann: Ground truth annotation object
|
|
4661
|
+
iou_threshold: IoU threshold for filtering
|
|
4218
4662
|
|
|
4219
|
-
:
|
|
4220
|
-
|
|
4221
|
-
:param pred_anns: List of Annotation objects containing predictions
|
|
4222
|
-
:type pred_anns: List[Annotation]
|
|
4223
|
-
:param dataset_id: ID of the dataset containing the images
|
|
4224
|
-
:type dataset_id: int
|
|
4225
|
-
:param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
|
|
4226
|
-
:type gt_image_ids: List[int]
|
|
4227
|
-
:param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
|
|
4228
|
-
ground truth box of the same class will be removed. None if no filtering is needed
|
|
4229
|
-
:type iou: Optional[float]
|
|
4230
|
-
:param meta: ProjectMeta object
|
|
4231
|
-
:type meta: Optional[ProjectMeta]
|
|
4232
|
-
:return: List of Annotation objects containing filtered predictions
|
|
4233
|
-
:rtype: List[Annotation]
|
|
4234
|
-
|
|
4235
|
-
Notes:
|
|
4236
|
-
------
|
|
4237
|
-
- Requires PyTorch and torchvision for IoU calculations
|
|
4238
|
-
- This method is useful for identifying new objects that aren't already annotated in the ground truth
|
|
4663
|
+
Returns:
|
|
4664
|
+
New annotation with filtered labels
|
|
4239
4665
|
"""
|
|
4240
|
-
if
|
|
4241
|
-
|
|
4242
|
-
|
|
4243
|
-
|
|
4244
|
-
|
|
4245
|
-
|
|
4246
|
-
|
|
4247
|
-
|
|
4248
|
-
|
|
4249
|
-
|
|
4250
|
-
|
|
4251
|
-
|
|
4252
|
-
|
|
4253
|
-
|
|
4254
|
-
|
|
4255
|
-
|
|
4256
|
-
|
|
4666
|
+
if not iou_threshold:
|
|
4667
|
+
return pred_ann
|
|
4668
|
+
|
|
4669
|
+
from shapely.geometry import box
|
|
4670
|
+
|
|
4671
|
+
def calculate_iou(geom1: Geometry, geom2: Geometry):
|
|
4672
|
+
"""Calculate IoU between two geometries using Shapely."""
|
|
4673
|
+
bbox1 = geom1.to_bbox()
|
|
4674
|
+
bbox2 = geom2.to_bbox()
|
|
4675
|
+
|
|
4676
|
+
box1 = box(bbox1.left, bbox1.top, bbox1.right, bbox1.bottom)
|
|
4677
|
+
box2 = box(bbox2.left, bbox2.top, bbox2.right, bbox2.bottom)
|
|
4678
|
+
|
|
4679
|
+
intersection = box1.intersection(box2).area
|
|
4680
|
+
union = box1.union(box2).area
|
|
4681
|
+
|
|
4682
|
+
return intersection / union if union > 0 else 0.0
|
|
4683
|
+
|
|
4684
|
+
new_labels = []
|
|
4685
|
+
pred_cls_bboxes = defaultdict(list)
|
|
4686
|
+
for label in pred_ann.labels:
|
|
4687
|
+
name_shape = (label.obj_class.name, label.geometry.name())
|
|
4688
|
+
pred_cls_bboxes[name_shape].append(label)
|
|
4689
|
+
|
|
4690
|
+
gt_cls_bboxes = defaultdict(list)
|
|
4691
|
+
for label in gt_ann.labels:
|
|
4692
|
+
name_shape = (label.obj_class.name, label.geometry.name())
|
|
4693
|
+
if name_shape not in pred_cls_bboxes:
|
|
4694
|
+
continue
|
|
4695
|
+
gt_cls_bboxes[name_shape].append(label)
|
|
4696
|
+
|
|
4697
|
+
for name_shape, pred in pred_cls_bboxes.items():
|
|
4698
|
+
gt = gt_cls_bboxes[name_shape]
|
|
4699
|
+
if len(gt) == 0:
|
|
4700
|
+
new_labels.extend(pred)
|
|
4701
|
+
continue
|
|
4702
|
+
|
|
4703
|
+
for pred_label in pred:
|
|
4704
|
+
# Check if this prediction has IoU < threshold with ALL GT boxes
|
|
4705
|
+
keep = True
|
|
4706
|
+
for gt_label in gt:
|
|
4707
|
+
iou = calculate_iou(pred_label.geometry, gt_label.geometry)
|
|
4708
|
+
if iou >= iou_threshold:
|
|
4709
|
+
keep = False
|
|
4710
|
+
break
|
|
4711
|
+
|
|
4712
|
+
if keep:
|
|
4713
|
+
new_labels.append(pred_label)
|
|
4714
|
+
|
|
4715
|
+
return pred_ann.clone(labels=new_labels)
|
|
4257
4716
|
|
|
4258
4717
|
|
|
4259
4718
|
def _filter_duplicated_predictions_from_ann(
|
|
@@ -4284,13 +4743,15 @@ def _filter_duplicated_predictions_from_ann(
|
|
|
4284
4743
|
- Predictions with classes not present in ground truth will be kept
|
|
4285
4744
|
- Requires PyTorch and torchvision for IoU calculations
|
|
4286
4745
|
"""
|
|
4746
|
+
if not iou_threshold:
|
|
4747
|
+
return pred_ann
|
|
4287
4748
|
|
|
4288
4749
|
try:
|
|
4289
4750
|
import torch
|
|
4290
4751
|
from torchvision.ops import box_iou
|
|
4291
4752
|
|
|
4292
4753
|
except ImportError:
|
|
4293
|
-
|
|
4754
|
+
return _filter_duplicated_predictions_from_ann_cpu(gt_ann, pred_ann, iou_threshold)
|
|
4294
4755
|
|
|
4295
4756
|
def _to_tensor(geom):
|
|
4296
4757
|
return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
|
|
@@ -4298,16 +4759,18 @@ def _filter_duplicated_predictions_from_ann(
|
|
|
4298
4759
|
new_labels = []
|
|
4299
4760
|
pred_cls_bboxes = defaultdict(list)
|
|
4300
4761
|
for label in pred_ann.labels:
|
|
4301
|
-
|
|
4762
|
+
name_shape = (label.obj_class.name, label.geometry.name())
|
|
4763
|
+
pred_cls_bboxes[name_shape].append(label)
|
|
4302
4764
|
|
|
4303
4765
|
gt_cls_bboxes = defaultdict(list)
|
|
4304
4766
|
for label in gt_ann.labels:
|
|
4305
|
-
|
|
4767
|
+
name_shape = (label.obj_class.name, label.geometry.name())
|
|
4768
|
+
if name_shape not in pred_cls_bboxes:
|
|
4306
4769
|
continue
|
|
4307
|
-
gt_cls_bboxes[
|
|
4770
|
+
gt_cls_bboxes[name_shape].append(label)
|
|
4308
4771
|
|
|
4309
|
-
for
|
|
4310
|
-
gt = gt_cls_bboxes[
|
|
4772
|
+
for name_shape, pred in pred_cls_bboxes.items():
|
|
4773
|
+
gt = gt_cls_bboxes[name_shape]
|
|
4311
4774
|
if len(gt) == 0:
|
|
4312
4775
|
new_labels.extend(pred)
|
|
4313
4776
|
continue
|
|
@@ -4321,6 +4784,63 @@ def _filter_duplicated_predictions_from_ann(
|
|
|
4321
4784
|
return pred_ann.clone(labels=new_labels)
|
|
4322
4785
|
|
|
4323
4786
|
|
|
4787
|
+
def _exclude_duplicated_predictions(
    api: Api,
    pred_anns: List[Annotation],
    dataset_id: int,
    gt_image_ids: List[int],
    iou: float = None,
    meta: Optional[ProjectMeta] = None,
):
    """
    Filter out predictions that significantly overlap with ground truth (GT) objects.

    This is a wrapper around the `_filter_duplicated_predictions_from_ann` function that does the following:
    - Skips filtering entirely unless ``iou`` is a number in the range (0, 1]
    - Gets the ProjectMeta object of the dataset's project if not provided
    - Downloads GT annotations for the specified image IDs
    - Filters out predictions that have an IoU greater than or equal to the specified threshold
      with any GT object of the same class and geometry type

    :param api: Supervisely API object
    :type api: Api
    :param pred_anns: List of Annotation objects containing predictions. The i-th annotation is
        compared with the GT annotation of the i-th image in ``gt_image_ids``. The list is
        updated in place and also returned.
    :type pred_anns: List[Annotation]
    :param dataset_id: ID of the dataset containing the images
    :type dataset_id: int
    :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
    :type gt_image_ids: List[int]
    :param iou: IoU threshold in the range (0, 1]. Predictions with IoU >= threshold with any
        ground truth object of the same class will be removed. Integers (e.g. ``1``) are accepted.
        None (or any value outside (0, 1]) disables filtering
    :type iou: Optional[float]
    :param meta: ProjectMeta object used to parse GT annotations. Downloaded if None
    :type meta: Optional[ProjectMeta]
    :return: ``pred_anns`` with filtered predictions
    :rtype: List[Annotation]

    Notes:
    ------
    - Uses PyTorch and torchvision for IoU calculations when they are installed,
      otherwise falls back to a CPU implementation
    - Predictions whose class is not present in the ground truth are kept
    - This method is useful for identifying new objects that aren't already annotated in the ground truth
    """
    # Accept ints as well as floats (a threshold such as ``1`` coming from JSON settings
    # is an int), but not bools, which are a subclass of int in Python.
    is_number = isinstance(iou, (int, float)) and not isinstance(iou, bool)
    if not (is_number and 0 < iou <= 1):
        # No usable threshold: return the predictions untouched.
        return pred_anns

    if meta is None:
        ds = api.dataset.get_info_by_id(dataset_id)
        meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
    gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
    gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]

    # Predictions are paired with GT annotations by position.
    for i in range(len(pred_anns)):
        before = len(pred_anns[i].labels)
        with Timer() as timer:
            pred_anns[i] = _filter_duplicated_predictions_from_ann(gt_anns[i], pred_anns[i], iou)
        after = len(pred_anns[i].labels)
        logger.debug(
            f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
        )
    return pred_anns
|
|
4842
|
+
|
|
4843
|
+
|
|
4324
4844
|
def _get_log_extra_for_inference_request(
|
|
4325
4845
|
inference_request_uuid, inference_request: Union[InferenceRequest, dict]
|
|
4326
4846
|
):
|
|
@@ -4347,8 +4867,8 @@ def _get_log_extra_for_inference_request(
|
|
|
4347
4867
|
"has_result": inference_request.final_result is not None,
|
|
4348
4868
|
"pending_results": inference_request.pending_num(),
|
|
4349
4869
|
"exception": inference_request.exception_json(),
|
|
4350
|
-
"result": inference_request._final_result,
|
|
4351
4870
|
"preparing_progress": progress,
|
|
4871
|
+
"result": inference_request.final_result is not None, # for backward compatibility
|
|
4352
4872
|
}
|
|
4353
4873
|
return log_extra
|
|
4354
4874
|
|
|
@@ -4428,7 +4948,7 @@ def get_gpu_count():
|
|
|
4428
4948
|
gpu_count = len(re.findall(r"GPU \d+:", nvidia_smi_output))
|
|
4429
4949
|
return gpu_count
|
|
4430
4950
|
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
|
|
4431
|
-
logger.
|
|
4951
|
+
logger.warning("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
|
|
4432
4952
|
return 0
|
|
4433
4953
|
|
|
4434
4954
|
|
|
@@ -4608,7 +5128,180 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation, model_prediction_suf
|
|
|
4608
5128
|
img_tags = None
|
|
4609
5129
|
if not any_label_updated:
|
|
4610
5130
|
labels = None
|
|
4611
|
-
ann = ann.clone(img_tags=
|
|
5131
|
+
ann = ann.clone(img_tags=img_tags)
|
|
5132
|
+
return meta, ann, meta_changed
|
|
5133
|
+
|
|
5134
|
+
|
|
5135
|
+
def update_meta_and_ann_for_video_annotation(
    meta: ProjectMeta, ann: VideoAnnotation, model_prediction_suffix: str = None
):
    """Update project meta and video annotation to match each other.

    Every object class and tag meta used in ``ann`` must exist in ``meta``:

    - missing object classes / tag metas are added to ``meta`` as is;
    - if an object class with the same name exists in ``meta`` with a different geometry type
      (and it is not ``AnyGeometry``), or a tag meta with the same name exists but is not
      compatible, a suffix is appended to the name and the renamed object class / tag meta is
      used instead (it is added to ``meta`` if it is not there yet).

    The default suffix is ``"_nn"``; ``model_prediction_suffix`` replaces it when given.

    :param meta: Project meta to update.
    :type meta: ProjectMeta
    :param ann: Video annotation (e.g. model predictions) to update.
    :type ann: VideoAnnotation
    :param model_prediction_suffix: Custom suffix for conflicting names. Defaults to ``"_nn"``.
    :type model_prediction_suffix: Optional[str]
    :return: Tuple of updated project meta, updated annotation and a flag telling whether meta was changed.
    :rtype: Tuple[ProjectMeta, VideoAnnotation, bool]
    :raises ValueError: If an object class or tag meta can't be matched with project meta even with the suffix.
    """
    obj_classes_suffixes = ["_nn"]
    tag_meta_suffixes = ["_nn"]
    if model_prediction_suffix is not None:
        obj_classes_suffixes = [model_prediction_suffix]
        tag_meta_suffixes = [model_prediction_suffix]
        logger.debug(
            f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
        )
    logger.debug("source meta", extra={"meta": meta.to_json()})
    meta_changed = False

    # Group figures by the key of the object they belong to once, instead of
    # rescanning all figures for every object.
    figures_by_object_key: Dict = {}
    for figure in ann.figures:
        figures_by_object_key.setdefault(figure.video_object.key(), []).append(figure)

    new_objects: List[VideoObject] = []
    new_figures: List[VideoFigure] = []
    any_object_updated = False
    for video_object in ann.objects:
        object_figures = figures_by_object_key.get(video_object.key(), [])
        source_obj_class = video_object.obj_class
        this_object_changed = False
        suffix_found = False
        for suffix in ["", *obj_classes_suffixes]:
            obj_class_name = source_obj_class.name + suffix
            obj_class = source_obj_class
            if suffix:
                # The plain name is taken by an incompatible class: try a renamed copy.
                # Always derive from the source class so suffixes never stack.
                obj_class = source_obj_class.clone(name=obj_class_name)
                video_object = video_object.clone(obj_class=obj_class)
                any_object_updated = True
                this_object_changed = True
            meta_obj_class = meta.get_obj_class(obj_class_name)
            if meta_obj_class is None:
                # obj class is not in meta, add it (with suffix, if any)
                meta = meta.add_obj_class(obj_class)
                meta_changed = True
                suffix_found = True
                break
            meta_geometry_name = meta_obj_class.geometry_type.geometry_name()
            if meta_geometry_name in (
                obj_class.geometry_type.geometry_name(),
                AnyGeometry.geometry_name(),
            ):
                # same geometry (or AnyGeometry in meta): use the meta obj class
                video_object = video_object.clone(obj_class=meta_obj_class)
                any_object_updated = True
                this_object_changed = True
                suffix_found = True
                break
        if not suffix_found:
            raise ValueError(
                f"Can't add obj class {source_obj_class.name} to project meta. "
                "Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
                "Please check if model geometry type is compatible with existing obj classes."
            )
        new_objects.append(video_object)
        if this_object_changed:
            object_figures = [
                figure.clone(video_object=video_object) for figure in object_figures
            ]
        # Figures of unchanged objects must be kept too: otherwise they are lost
        # when the frames are rebuilt from ``new_figures`` below.
        new_figures.extend(object_figures)

    if any_object_updated:
        frames_figures = {}
        for figure in new_figures:
            frames_figures.setdefault(figure.frame_index, []).append(figure)
        new_frames = FrameCollection(
            [
                Frame(index=frame_index, figures=figures)
                for frame_index, figures in frames_figures.items()
            ]
        )
        # Wrap objects in a collection, consistent with the tag-update branch below.
        ann = ann.clone(objects=VideoObjectCollection(new_objects), frames=new_frames)

    # Collect tag metas used by object tags and video-level tags (first occurrence wins).
    ann_tag_metas: Dict[str, TagMeta] = {}
    for video_object in ann.objects:
        for tag in video_object.tags:
            ann_tag_metas.setdefault(tag.meta.name, tag.meta)
    for tag in ann.tags:
        ann_tag_metas.setdefault(tag.meta.name, tag.meta)

    # original tag meta name -> tag meta (from / added to project meta) to use instead
    changed_tag_metas: Dict[str, TagMeta] = {}
    for ann_tag_meta in ann_tag_metas.values():
        meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
        if meta_tag_meta is None:
            meta = meta.add_tag_meta(ann_tag_meta)
            meta_changed = True
            continue
        if meta_tag_meta.is_compatible(ann_tag_meta):
            continue
        suffix_found = False
        for suffix in tag_meta_suffixes:
            new_tag_meta_name = ann_tag_meta.name + suffix
            meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
            if meta_tag_meta is None:
                new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
                meta = meta.add_tag_meta(new_tag_meta)
                changed_tag_metas[ann_tag_meta.name] = new_tag_meta
                meta_changed = True
                suffix_found = True
                break
            if meta_tag_meta.is_compatible(ann_tag_meta):
                changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
                suffix_found = True
                break
        if not suffix_found:
            raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")

    if changed_tag_metas:

        def _remap_tags(tags):
            """Return (tags with renamed metas applied, whether any tag was replaced)."""
            remapped = []
            updated = False
            for tag in tags:
                new_meta = changed_tag_metas.get(tag.meta.name)
                if new_meta is None:
                    remapped.append(tag)
                else:
                    remapped.append(tag.clone(meta=new_meta))
                    updated = True
            return remapped, updated

        objects = []
        any_object_updated = False
        for video_object in ann.objects:
            object_tags, any_tag_updated = _remap_tags(video_object.tags)
            if any_tag_updated:
                # Video objects hold video tags: keep the same collection type.
                video_object = video_object.clone(tags=VideoTagCollection(object_tags))
                any_object_updated = True
            objects.append(video_object)

        video_tags, any_tag_updated = _remap_tags(ann.tags)
        if any_tag_updated or any_object_updated:
            ann = ann.clone(
                tags=VideoTagCollection(video_tags) if any_tag_updated else None,
                objects=VideoObjectCollection(objects) if any_object_updated else None,
            )

    return meta, ann, meta_changed
|
|
4613
5306
|
|
|
4614
5307
|
|
|
@@ -4722,7 +5415,8 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
|
|
|
4722
5415
|
return data[key]
|
|
4723
5416
|
return None
|
|
4724
5417
|
|
|
4725
|
-
|
|
5418
|
+
|
|
5419
|
+
def torch_load_safe(checkpoint_path: str, device: str = "cpu"):
|
|
4726
5420
|
import torch # pylint: disable=import-error
|
|
4727
5421
|
|
|
4728
5422
|
# TODO: handle torch.load(weights_only=True) - change in torch 2.6.0
|