supervisely 6.73.410-py3-none-any.whl → 6.73.470-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of supervisely might be problematic.
- supervisely/__init__.py +136 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/json_geometries_map.py +2 -0
- supervisely/annotation/label.py +80 -3
- supervisely/api/annotation_api.py +9 -9
- supervisely/api/api.py +67 -43
- supervisely/api/app_api.py +72 -5
- supervisely/api/dataset_api.py +108 -33
- supervisely/api/entity_annotation/figure_api.py +113 -49
- supervisely/api/image_api.py +82 -0
- supervisely/api/module_api.py +10 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/pointcloud/pointcloud_api.py +38 -0
- supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
- supervisely/api/project_api.py +213 -6
- supervisely/api/task_api.py +11 -1
- supervisely/api/video/video_annotation_api.py +4 -2
- supervisely/api/video/video_api.py +79 -1
- supervisely/api/video/video_figure_api.py +24 -11
- supervisely/api/volume/volume_api.py +38 -0
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +14 -6
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +88 -0
- supervisely/app/fastapi/subapp.py +175 -42
- supervisely/app/fastapi/templating.py +1 -1
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +11 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
- supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
- supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
- supervisely/app/widgets/fast_table/fast_table.py +713 -126
- supervisely/app/widgets/fast_table/script.js +492 -95
- supervisely/app/widgets/fast_table/style.css +54 -0
- supervisely/app/widgets/fast_table/template.html +45 -5
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +523 -0
- supervisely/app/widgets/heatmap/script.js +378 -0
- supervisely/app/widgets/heatmap/style.css +227 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/input_tag/input_tag.py +102 -15
- supervisely/app/widgets/input_tag_list/__init__.py +0 -0
- supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
- supervisely/app/widgets/input_tag_list/template.html +70 -0
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/app/widgets/transfer/style.css +3 -0
- supervisely/app/widgets/transfer/template.html +3 -1
- supervisely/app/widgets/transfer/transfer.py +48 -45
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
- supervisely/convert/video/video_converter.py +2 -2
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/io/env.py +161 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/cache.py +37 -17
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +953 -211
- supervisely/nn/inference/inference_request.py +15 -8
- supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
- supervisely/nn/inference/object_detection/object_detection.py +1 -0
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
- supervisely/nn/inference/predict_app/gui/gui.py +915 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +399 -0
- supervisely/nn/inference/predict_app/predict_app.py +176 -0
- supervisely/nn/inference/session.py +47 -39
- supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
- supervisely/nn/inference/tracking/point_tracking.py +5 -1
- supervisely/nn/inference/tracking/tracker_interface.py +4 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/model_api.py +44 -22
- supervisely/nn/model/prediction.py +15 -1
- supervisely/nn/model/prediction_session.py +70 -14
- supervisely/nn/prediction_dto.py +7 -0
- supervisely/nn/tracker/__init__.py +6 -8
- supervisely/nn/tracker/base_tracker.py +54 -0
- supervisely/nn/tracker/botsort/__init__.py +1 -0
- supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
- supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
- supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
- supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
- supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
- supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
- supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
- supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
- supervisely/nn/tracker/botsort_tracker.py +273 -0
- supervisely/nn/tracker/calculate_metrics.py +264 -0
- supervisely/nn/tracker/utils.py +273 -0
- supervisely/nn/tracker/visualize.py +520 -0
- supervisely/nn/training/gui/gui.py +152 -49
- supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
- supervisely/nn/training/gui/training_artifacts.py +3 -1
- supervisely/nn/training/train_app.py +225 -46
- supervisely/project/pointcloud_episode_project.py +12 -8
- supervisely/project/pointcloud_project.py +12 -8
- supervisely/project/project.py +221 -75
- supervisely/template/experiment/experiment.html.jinja +105 -55
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- supervisely/versions.json +3 -1
- supervisely/video/sampling.py +42 -20
- supervisely/video/video.py +41 -12
- supervisely/video_annotation/video_figure.py +38 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
- supervisely_lib/__init__.py +6 -1
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- supervisely/nn/tracker/bot_sort/__init__.py +0 -21
- supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
- supervisely/nn/tracker/bot_sort/matching.py +0 -127
- supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
- supervisely/nn/tracker/deep_sort/__init__.py +0 -6
- supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
- supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
- supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
- supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
- supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
- supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
- supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
- supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
- supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
- supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
- supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
- supervisely/nn/tracker/tracker.py +0 -285
- supervisely/nn/tracker/utils/kalman_filter.py +0 -492
- supervisely/nn/tracking/__init__.py +0 -1
- supervisely/nn/tracking/boxmot.py +0 -114
- supervisely/nn/tracking/tracking.py +0 -24
- /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
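
The headline change is a rework of the tracking subsystem: the legacy `bot_sort` and `deep_sort` packages (with `BoTTracker` and `DeepSortTracker`) are deleted, and a new `botsort` package with a `BotSortTracker` built on an OSNet re-ID backbone takes their place. The sketch below is inferred from the `inference.py` hunks that follow, not from documented API: the constructor arguments, the shape of `update()`'s return value, and the `video_annotation` property are taken from `_tracker_init` and `_apply_tracker_to_anns`, while the frame and annotation stand-ins are placeholders.

```python
import numpy as np
import supervisely as sly
from supervisely.nn.tracker import BotSortTracker  # new in 6.73.470

# Sketch only: kwargs and the update() contract are inferred from this diff.
tracker = BotSortTracker(settings={}, device="cpu")
frame = np.zeros((480, 640, 3), dtype=np.uint8)      # stand-in video frame
ann = sly.Annotation(img_size=(480, 640))            # per-frame detections
matches = tracker.update(frame, ann)                 # [{"track_id": ..., "label": ...}, ...]
video_ann_json = tracker.video_annotation.to_json()  # cumulative VideoAnnotation
```

The largest single diff is `supervisely/nn/inference/inference.py` (+953 -211), shown below.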
```diff
--- a/supervisely/nn/inference/inference.py
+++ b/supervisely/nn/inference/inference.py
@@ -11,6 +11,7 @@ import subprocess
 import tempfile
 import threading
 import time
+import uuid
 from collections import OrderedDict, defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict, dataclass
@@ -19,6 +20,7 @@ from pathlib import Path
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 from urllib.request import urlopen
 
+import _pickle
 import numpy as np
 import requests
 import uvicorn
@@ -34,7 +36,6 @@ import supervisely.io.env as sly_env
 import supervisely.io.fs as sly_fs
 import supervisely.io.json as sly_json
 import supervisely.nn.inference.gui as GUI
-from supervisely.nn.experiments import ExperimentInfo
 from supervisely import DatasetInfo, batched
 from supervisely._utils import (
     add_callback,
@@ -45,13 +46,14 @@ from supervisely._utils import (
     rand_str,
 )
 from supervisely.annotation.annotation import Annotation
-from supervisely.annotation.label import Label
+from supervisely.annotation.label import Label, LabelingStatus
 from supervisely.annotation.obj_class import ObjClass
 from supervisely.annotation.tag_collection import TagCollection
 from supervisely.annotation.tag_meta import TagMeta, TagValueType
 from supervisely.api.api import Api, ApiField
 from supervisely.api.app_api import WorkflowMeta, WorkflowSettings
 from supervisely.api.image_api import ImageInfo
+from supervisely.api.video.video_api import VideoInfo
 from supervisely.app.content import get_data_dir
 from supervisely.app.fastapi.subapp import (
     Application,
@@ -67,15 +69,17 @@ from supervisely.decorators.inference import (
     process_images_batch_sliding_window,
 )
 from supervisely.geometry.any_geometry import AnyGeometry
+from supervisely.geometry.geometry import Geometry
 from supervisely.imaging.color import get_predefined_colors
 from supervisely.io.fs import list_files
+from supervisely.nn.experiments import ExperimentInfo
 from supervisely.nn.inference.cache import InferenceImageCache
 from supervisely.nn.inference.inference_request import (
     InferenceRequest,
     InferenceRequestsManager,
 )
 from supervisely.nn.inference.uploader import Uploader
-from supervisely.nn.model.model_api import Prediction
+from supervisely.nn.model.model_api import ModelAPI, Prediction
 from supervisely.nn.prediction_dto import Prediction as PredictionDTO
 from supervisely.nn.utils import (
     CheckpointInfo,
@@ -93,7 +97,18 @@ from supervisely.project.project_meta import ProjectMeta
 from supervisely.sly_logger import logger
 from supervisely.task.progress import Progress
 from supervisely.video.video import ALLOWED_VIDEO_EXTENSIONS, VideoFrameReader
-from supervisely.
+from supervisely.video_annotation.frame import Frame
+from supervisely.video_annotation.frame_collection import FrameCollection
+from supervisely.video_annotation.video_annotation import VideoAnnotation
+from supervisely.video_annotation.video_figure import VideoFigure
+from supervisely.video_annotation.video_object import VideoObject
+from supervisely.video_annotation.video_object_collection import VideoObjectCollection
+from supervisely.video_annotation.video_tag_collection import VideoTagCollection
+from supervisely.video_annotation.key_id_map import KeyIdMap
+from supervisely.video_annotation.video_object_collection import (
+    VideoObject,
+    VideoObjectCollection,
+)
 
 try:
     from typing import Literal
@@ -140,6 +155,7 @@ class Inference:
     """Default batch size for inference"""
     INFERENCE_SETTINGS: str = None
     """Path to file with custom inference settings"""
+    DEFAULT_IOU_MERGE_THRESHOLD: float = 0.9
 
     def __init__(
         self,
@@ -193,7 +209,6 @@ class Inference:
         self._task_id = None
         self._sliding_window_mode = sliding_window_mode
         self._autostart_delay_time = 5 * 60  # 5 min
-        self._tracker = None
         self._hardware: str = None
         if custom_inference_settings is None:
             if self.INFERENCE_SETTINGS is not None:
@@ -383,7 +398,7 @@ class Inference:
             if m_name and m_name.lower() == model.lower():
                 return m
             return None
-
+
         runtime = get_runtime(runtime)
         logger.debug(f"Runtime: {runtime}")
 
@@ -427,7 +442,7 @@ class Inference:
 
             device = "cuda" if torch.cuda.is_available() else "cpu"
         except Exception as e:
-            logger.
+            logger.warning(
                 f"Device auto detection failed, set to default 'cpu', reason: {repr(e)}"
             )
             device = "cpu"
@@ -863,13 +878,57 @@ class Inference:
             self.gui.download_progress.hide()
         return local_model_files
 
+    def _fallback_download_custom_model_pt(self, deploy_params: dict):
+        """
+        Downloads the PyTorch checkpoint from Team Files if TensorRT is failed to load.
+        """
+        team_id = sly_env.team_id()
+
+        checkpoint_name = sly_fs.get_file_name(deploy_params["model_files"]["checkpoint"])
+        artifacts_dir = deploy_params["model_info"]["artifacts_dir"]
+        checkpoints_dir = os.path.join(artifacts_dir, "checkpoints")
+        checkpoint_ext = sly_fs.get_file_ext(deploy_params["model_info"]["checkpoints"][0])
+
+        pt_checkpoint_name = f"{checkpoint_name}{checkpoint_ext}"
+        remote_checkpoint_path = os.path.join(checkpoints_dir, pt_checkpoint_name)
+        local_checkpoint_path = os.path.join(self.model_dir, pt_checkpoint_name)
+
+        file_info = self.api.file.get_info_by_path(team_id, remote_checkpoint_path)
+        file_size = file_info.sizeb
+        if self.gui is not None:
+            with self.gui.download_progress(
+                message=f"Fallback. Downloading PyTorch checkpoint: '{pt_checkpoint_name}'",
+                total=file_size,
+                unit="bytes",
+                unit_scale=True,
+            ) as download_pbar:
+                self.gui.download_progress.show()
+                self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path, progress_cb=download_pbar.update)
+            self.gui.download_progress.hide()
+        else:
+            self.api.file.download(team_id, remote_checkpoint_path, local_checkpoint_path)
+
+        return local_checkpoint_path
+
+    def _remove_exported_checkpoints(self, checkpoint_path: str):
+        """
+        Removes the exported checkpoints for provided PyTorch checkpoint path.
+        """
+        checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
+        onnx_path = checkpoint_path.replace(checkpoint_ext, ".onnx")
+        engine_path = checkpoint_path.replace(checkpoint_ext, ".engine")
+        if os.path.exists(onnx_path):
+            sly_fs.silent_remove(onnx_path)
+        if os.path.exists(engine_path):
+            sly_fs.silent_remove(engine_path)
+
     def _download_custom_model(self, model_files: dict, log_progress: bool = True):
         """
         Downloads the custom model data.
         """
         team_id = sly_env.team_id()
         local_model_files = {}
-
+
         # Sort files to download 'checkpoint' first
         files_order = sorted(model_files.keys(), key=lambda x: (0 if x == "checkpoint" else 1, x))
         for file in files_order:
@@ -905,17 +964,23 @@ class Inference:
                     if extracted_files:
                         local_model_files[file] = file_path
                         return local_model_files
+                except _pickle.UnpicklingError as e:
+                    # TODO: raise error - checkpoint is corrupted
+                    logger.warning(f"Couldn't load '{file_name}'. Checkpoint might be corrupted. Error: {repr(e)}")
+                    logger.warning("Model files will be downloaded from Team Files")
+                    local_model_files[file] = file_path
+                    continue
                 except Exception as e:
-                    logger.
-                    logger.
+                    logger.warning(f"Failed to process checkpoint '{file_name}' to extract auxiliary files: {repr(e)}")
+                    logger.warning("Model files will be downloaded from Team Files")
                     local_model_files[file] = file_path
                     continue
-
+
             local_model_files[file] = file_path
         if log_progress:
             self.gui.download_progress.hide()
         return local_model_files
-
+
     def _get_deploy_parameters_from_custom_checkpoint(self, checkpoint_path: str, device: str, runtime: str) -> dict:
         def _read_experiment_info(artifacts_dir: str) -> Optional[dict]:
             exp_path = os.path.join(artifacts_dir, "experiment_info.json")
@@ -976,8 +1041,7 @@ class Inference:
         # --- LOCAL ---
         try:
             logger.debug("Reading state dict...")
-
-            ckpt = torch.load(checkpoint_path, map_location="cpu")
+            ckpt = torch_load_safe(checkpoint_path)
             model_info = ckpt.get("model_info", {})
             model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
             model_files["checkpoint"] = checkpoint_path
@@ -1017,10 +1081,8 @@ class Inference:
         if file_ext not in (".pth", ".pt"):
             return extracted_files
 
-        import torch  # pylint: disable=import-error
         logger.debug(f"Reading checkpoint: {checkpoint_path}")
-        checkpoint =
-
+        checkpoint = torch_load_safe(checkpoint_path)
         # 1. Extract additional model files embedded into checkpoint (if any)
         ckpt_files = checkpoint.get("model_files", None)
         if ckpt_files and isinstance(ckpt_files, dict):
```
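Both `torch.load` call sites above are switched to a `torch_load_safe` helper whose definition is not part of this excerpt. A plausible sketch of such a wrapper, assuming it exists to cope with PyTorch's stricter `weights_only` default, would be:

```python
import torch

# Hypothetical sketch: torch_load_safe is referenced but not defined in this diff.
def torch_load_safe(checkpoint_path: str):
    try:
        # Safe path: refuses to unpickle arbitrary Python objects.
        return torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    except Exception:
        # Fallback for older checkpoints that embed custom objects
        # (matches the _pickle.UnpicklingError handling added above).
        return torch.load(checkpoint_path, map_location="cpu", weights_only=False)
```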
```diff
@@ -1057,7 +1119,41 @@ class Inference:
         self.runtime = deploy_params.get("runtime", RuntimeType.PYTORCH)
         self.model_precision = deploy_params.get("model_precision", ModelPrecision.FP32)
         self._hardware = get_hardware_info(self.device)
-
+
+        model_files = deploy_params.get("model_files", None)
+        if model_files is not None:
+            checkpoint_path = deploy_params["model_files"]["checkpoint"]
+            checkpoint_ext = sly_fs.get_file_ext(checkpoint_path)
+            if self.runtime == RuntimeType.TENSORRT and checkpoint_ext == ".engine":
+                try:
+                    self.load_model(**deploy_params)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to load model with TensorRT. Downloading PyTorch to export to TensorRT. Error: {repr(e)}"
+                    )
+                    checkpoint_path = self._fallback_download_custom_model_pt(deploy_params)
+                    deploy_params["model_files"]["checkpoint"] = checkpoint_path
+                    logger.info("Exporting PyTorch model to TensorRT...")
+                    self._remove_exported_checkpoints(checkpoint_path)
+                    checkpoint_path = self.export_tensorrt(deploy_params)
+                    deploy_params["model_files"]["checkpoint"] = checkpoint_path
+                    self.load_model(**deploy_params)
+            if checkpoint_ext in (".pt", ".pth") and not self.runtime == RuntimeType.PYTORCH:
+                if self.runtime == RuntimeType.ONNXRUNTIME:
+                    logger.info("Exporting PyTorch model to ONNX...")
+                    self._remove_exported_checkpoints(checkpoint_path)
+                    checkpoint_path = self.export_onnx(deploy_params)
+                elif self.runtime == RuntimeType.TENSORRT:
+                    logger.info("Exporting PyTorch model to TensorRT...")
+                    self._remove_exported_checkpoints(checkpoint_path)
+                    checkpoint_path = self.export_tensorrt(deploy_params)
+                deploy_params["model_files"]["checkpoint"] = checkpoint_path
+                self.load_model(**deploy_params)
+            else:
+                self.load_model(**deploy_params)
+        else:
+            self.load_model(**deploy_params)
+
         self._model_served = True
         self._deploy_params = deploy_params
         if self._task_id is not None and is_production():
@@ -1159,6 +1255,8 @@ class Inference:
         if model_source == ModelSource.CUSTOM:
             self._set_model_meta_custom_model(model_info)
             self._set_checkpoint_info_custom_model(deploy_params)
+        elif model_source == ModelSource.PRETRAINED:
+            self._set_checkpoint_info_pretrained(deploy_params)
 
         try:
             if is_production():
@@ -1232,6 +1330,19 @@ class Inference:
             model_source=ModelSource.CUSTOM,
         )
 
+    def _set_checkpoint_info_pretrained(self, deploy_params: dict):
+        checkpoint_name = os.path.basename(deploy_params["model_files"]["checkpoint"])
+        model_name = _get_model_name(deploy_params["model_info"])
+        checkpoint_url = deploy_params["model_info"]["meta"]["model_files"]["checkpoint"]
+        model_source = ModelSource.PRETRAINED
+        self.checkpoint_info = CheckpointInfo(
+            checkpoint_name=checkpoint_name,
+            model_name=model_name,
+            architecture=self.FRAMEWORK_NAME,
+            checkpoint_url=checkpoint_url,
+            model_source=model_source,
+        )
+
     def shutdown_model(self):
         self._model_served = False
         self._model_frozen = False
@@ -1252,6 +1363,26 @@ class Inference:
     def get_classes(self) -> List[str]:
         return self.classes
 
+    def _tracker_init(self, tracker: str, tracker_settings: dict):
+        # Check if tracking is supported for this model
+        info = self.get_info()
+        tracking_support = info.get("tracking_on_videos_support", False)
+
+        if not tracking_support:
+            logger.debug("Tracking is not supported for this model")
+            return None
+
+        if tracker == "botsort":
+            from supervisely.nn.tracker import BotSortTracker
+
+            device = tracker_settings.get("device", self.device)
+            logger.debug(f"Initializing BotSort tracker with device: {device}")
+            return BotSortTracker(settings=tracker_settings, device=device)
+        else:
+            if tracker is not None:
+                logger.warning(f"Unknown tracking type: {tracker}. Tracking is disabled.")
+            return None
+
```
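`_tracker_init` refuses to create a tracker unless the model reports `tracking_on_videos_support`, which the next hunk defaults to `False` in `get_info()`. A model that wants the new BoT-SORT tracking would therefore opt in along these lines (a sketch, not code from the package):

```python
from supervisely.nn.inference import Inference

class MyDetector(Inference):
    def get_info(self):
        info = super().get_info()
        # Opt in to tracking-by-detection; "botsort" is the only
        # algorithm wired up in _tracker_init() as of this release.
        info["tracking_on_videos_support"] = True
        return info
```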
```diff
@@ -1260,15 +1391,15 @@ class Inference:
             if classes is not None:
                 num_classes = len(classes)
         except NotImplementedError:
-            logger.
+            logger.warning(f"get_classes() function not implemented for {type(self)} object.")
         except AttributeError:
-            logger.
+            logger.warning("Probably, get_classes() function not working without model deploy.")
         except Exception as exc:
-            logger.
+            logger.warning("Unknown exception. Please, contact support")
             logger.exception(exc)
 
         if num_classes is None:
-            logger.
+            logger.warning(f"get_classes() function return {classes}; skip classes processing.")
 
         return {
             "app_name": get_name_from_env(default="Neural Network Serving"),
@@ -1277,15 +1408,51 @@ class Inference:
             "sliding_window_support": self.sliding_window_mode,
             "videos_support": True,
             "async_video_inference_support": True,
-            "tracking_on_videos_support":
+            "tracking_on_videos_support": False,
             "async_image_inference_support": True,
-            "tracking_algorithms": ["
+            "tracking_algorithms": ["botsort"],
             "batch_inference_support": self.is_batch_inference_supported(),
             "max_batch_size": self.max_batch_size,
         }
 
     # pylint: enable=method-hidden
 
+    def get_tracking_settings(self) -> Dict[str, Dict[str, Any]]:
+        """
+        Get default parameters for all available tracking algorithms.
+
+        Returns:
+            {"botsort": {"track_high_thresh": 0.6, ...}}
+            Empty dict if tracking not supported.
+        """
+        info = self.get_info()
+        trackers_params = {}
+
+        tracking_support = info.get("tracking_on_videos_support")
+        if not tracking_support:
+            return trackers_params
+
+        tracking_algorithms = info.get("tracking_algorithms", [])
+
+        for tracker_name in tracking_algorithms:
+            try:
+                if tracker_name == "botsort":
+                    from supervisely.nn.tracker import BotSortTracker
+
+                    trackers_params[tracker_name] = BotSortTracker.get_default_params()
+                # Add other trackers here as elif blocks
+                else:
+                    logger.debug(f"Tracker '{tracker_name}' not implemented")
+            except Exception as e:
+                logger.warning(f"Failed to get params for '{tracker_name}': {e}")
+
+        INTERNAL_FIELDS = {"device", "fps"}
+        for tracker_name, params in trackers_params.items():
+            trackers_params[tracker_name] = {
+                k: v for k, v in params.items() if k not in INTERNAL_FIELDS
+            }
+        return trackers_params
+
     def get_human_readable_info(self, replace_none_with: Optional[str] = None):
         hr_info = {}
         info = self.get_info()
```
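`get_tracking_settings()` is also exposed over HTTP further down in this diff (`POST /get_tracking_settings`), so clients can discover tracker defaults before starting a run. A client-side sketch, assuming a locally deployed serving app; the URL and the empty JSON body are assumptions:

```python
import requests

# Sketch: query tracker defaults from a deployed serving app.
resp = requests.post("http://localhost:8000/get_tracking_settings", json={}, timeout=60)
print(resp.json())  # e.g. {"botsort": {"track_high_thresh": 0.6, ...}}
```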
@@ -1401,8 +1568,12 @@ class Inference:
|
|
|
1401
1568
|
# for example empty mask
|
|
1402
1569
|
continue
|
|
1403
1570
|
if isinstance(label, list):
|
|
1571
|
+
for lb in label:
|
|
1572
|
+
lb.status = LabelingStatus.AUTO
|
|
1404
1573
|
labels.extend(label)
|
|
1405
1574
|
continue
|
|
1575
|
+
|
|
1576
|
+
label.status = LabelingStatus.AUTO
|
|
1406
1577
|
labels.append(label)
|
|
1407
1578
|
|
|
1408
1579
|
# create annotation with correct image resolution
|
|
@@ -1447,7 +1618,7 @@ class Inference:
|
|
|
1447
1618
|
if api is None:
|
|
1448
1619
|
api = self.api
|
|
1449
1620
|
return api
|
|
1450
|
-
|
|
1621
|
+
|
|
1451
1622
|
def _inference_auto(
|
|
1452
1623
|
self,
|
|
1453
1624
|
source: List[Union[str, np.ndarray]],
|
|
@@ -1833,24 +2004,12 @@ class Inference:
|
|
|
1833
2004
|
else:
|
|
1834
2005
|
n_frames = frames_reader.frames_count()
|
|
1835
2006
|
|
|
1836
|
-
|
|
1837
|
-
from supervisely.nn.tracker import BoTTracker
|
|
1838
|
-
|
|
1839
|
-
tracker = BoTTracker(state)
|
|
1840
|
-
elif tracking == "deepsort":
|
|
1841
|
-
from supervisely.nn.tracker import DeepSortTracker
|
|
1842
|
-
|
|
1843
|
-
tracker = DeepSortTracker(state)
|
|
1844
|
-
else:
|
|
1845
|
-
if tracking is not None:
|
|
1846
|
-
logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
|
|
1847
|
-
tracker = None
|
|
2007
|
+
inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
|
|
1848
2008
|
|
|
1849
2009
|
progress_total = (n_frames + step - 1) // step
|
|
1850
2010
|
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
|
|
1851
2011
|
|
|
1852
2012
|
results = []
|
|
1853
|
-
tracks_data = {}
|
|
1854
2013
|
for batch in batched(
|
|
1855
2014
|
range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
|
|
1856
2015
|
batch_size,
|
|
@@ -1870,28 +2029,30 @@ class Inference:
|
|
|
1870
2029
|
source=frames,
|
|
1871
2030
|
settings=inference_settings,
|
|
1872
2031
|
)
|
|
2032
|
+
|
|
2033
|
+
if inference_request.tracker is not None:
|
|
2034
|
+
anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
|
|
2035
|
+
|
|
1873
2036
|
predictions = [
|
|
1874
2037
|
Prediction(ann, model_meta=self.model_meta, frame_index=frame_index)
|
|
1875
2038
|
for ann, frame_index in zip(anns, batch)
|
|
1876
2039
|
]
|
|
2040
|
+
|
|
1877
2041
|
for pred, this_slides_data in zip(predictions, slides_data):
|
|
1878
2042
|
pred.extra_data["slides_data"] = this_slides_data
|
|
1879
2043
|
batch_results = self._format_output(predictions)
|
|
1880
|
-
|
|
1881
|
-
for frame_index, frame, ann in zip(batch, frames, anns):
|
|
1882
|
-
tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
|
|
2044
|
+
|
|
1883
2045
|
inference_request.add_results(batch_results)
|
|
1884
2046
|
inference_request.done(len(batch_results))
|
|
1885
2047
|
logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
|
|
1886
2048
|
video_ann_json = None
|
|
1887
|
-
if tracker is not None:
|
|
2049
|
+
if inference_request.tracker is not None:
|
|
1888
2050
|
inference_request.set_stage("Postprocess...", 0, 1)
|
|
1889
|
-
video_ann_json = tracker.
|
|
1890
|
-
tracks_data, (video_height, video_witdth), n_frames
|
|
1891
|
-
).to_json()
|
|
2051
|
+
video_ann_json = inference_request.tracker.video_annotation.to_json()
|
|
1892
2052
|
inference_request.done()
|
|
1893
2053
|
result = {"ann": results, "video_ann": video_ann_json}
|
|
1894
2054
|
inference_request.final_result = result.copy()
|
|
2055
|
+
return video_ann_json
|
|
1895
2056
|
|
|
1896
2057
|
def _inference_image_ids(
|
|
1897
2058
|
self,
|
|
@@ -1915,10 +2076,11 @@ class Inference:
|
|
|
1915
2076
|
raise ValueError("Image ids are not provided")
|
|
1916
2077
|
if not isinstance(image_ids, list):
|
|
1917
2078
|
image_ids = [image_ids]
|
|
2079
|
+
model_prediction_suffix = state.get("model_prediction_suffix", None)
|
|
1918
2080
|
upload_mode = state.get("upload_mode", None)
|
|
1919
2081
|
iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
|
|
1920
2082
|
if upload_mode == "iou_merge" and iou_merge_threshold is None:
|
|
1921
|
-
iou_merge_threshold = 0.
|
|
2083
|
+
iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD # TODO: change to 0.9
|
|
1922
2084
|
|
|
1923
2085
|
images_infos = api.image.get_info_by_id_batch(image_ids)
|
|
1924
2086
|
images_infos_dict = {im_info.id: im_info for im_info in images_infos}
|
|
@@ -1979,6 +2141,7 @@ class Inference:
|
|
|
1979
2141
|
progress_cb=inference_request.done,
|
|
1980
2142
|
iou_merge_threshold=iou_merge_threshold,
|
|
1981
2143
|
inference_request=inference_request,
|
|
2144
|
+
model_prediction_suffix=model_prediction_suffix,
|
|
1982
2145
|
)
|
|
1983
2146
|
|
|
1984
2147
|
_add_results_to_request = partial(
|
|
@@ -1994,8 +2157,8 @@ class Inference:
|
|
|
1994
2157
|
with Uploader(upload_f, logger=logger) as uploader:
|
|
1995
2158
|
for image_ids_batch in batched(image_ids, batch_size=batch_size):
|
|
1996
2159
|
if uploader.has_exception():
|
|
1997
|
-
exception = uploader.exception
|
|
1998
|
-
raise
|
|
2160
|
+
exception = uploader.exception
|
|
2161
|
+
raise exception
|
|
1999
2162
|
if inference_request.is_stopped():
|
|
2000
2163
|
logger.debug(
|
|
2001
2164
|
f"Cancelling inference project...",
|
|
@@ -2039,7 +2202,7 @@ class Inference:
|
|
|
2039
2202
|
video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
|
|
2040
2203
|
if video_id is None:
|
|
2041
2204
|
raise ValueError("Video id is not provided")
|
|
2042
|
-
video_info = api.video.get_info_by_id(video_id)
|
|
2205
|
+
video_info = api.video.get_info_by_id(video_id, force_metadata_for_links=True)
|
|
2043
2206
|
start_frame_index = get_value_for_keys(
|
|
2044
2207
|
state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
|
|
2045
2208
|
)
|
|
@@ -2069,18 +2232,8 @@ class Inference:
|
|
|
2069
2232
|
else:
|
|
2070
2233
|
n_frames = video_info.frames_count
|
|
2071
2234
|
|
|
2072
|
-
|
|
2073
|
-
from supervisely.nn.tracker import BoTTracker
|
|
2235
|
+
inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
|
|
2074
2236
|
|
|
2075
|
-
tracker = BoTTracker(state)
|
|
2076
|
-
elif tracking == "deepsort":
|
|
2077
|
-
from supervisely.nn.tracker import DeepSortTracker
|
|
2078
|
-
|
|
2079
|
-
tracker = DeepSortTracker(state)
|
|
2080
|
-
else:
|
|
2081
|
-
if tracking is not None:
|
|
2082
|
-
logger.warning(f"Unknown tracking type: {tracking}. Tracking is disabled.")
|
|
2083
|
-
tracker = None
|
|
2084
2237
|
logger.debug(
|
|
2085
2238
|
f"Video info:",
|
|
2086
2239
|
extra=dict(
|
|
@@ -2097,7 +2250,6 @@ class Inference:
|
|
|
2097
2250
|
progress_total = (n_frames + step - 1) // step
|
|
2098
2251
|
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
|
|
2099
2252
|
|
|
2100
|
-
tracks_data = {}
|
|
2101
2253
|
for batch in batched(
|
|
2102
2254
|
range(start_frame_index, start_frame_index + direction * n_frames, direction * step),
|
|
2103
2255
|
batch_size,
|
|
@@ -2116,6 +2268,10 @@ class Inference:
|
|
|
2116
2268
|
source=frames,
|
|
2117
2269
|
settings=inference_settings,
|
|
2118
2270
|
)
|
|
2271
|
+
|
|
2272
|
+
if inference_request.tracker is not None:
|
|
2273
|
+
anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
|
|
2274
|
+
|
|
2119
2275
|
predictions = [
|
|
2120
2276
|
Prediction(
|
|
2121
2277
|
ann,
|
|
@@ -2130,20 +2286,173 @@ class Inference:
|
|
|
2130
2286
|
for pred, this_slides_data in zip(predictions, slides_data):
|
|
2131
2287
|
pred.extra_data["slides_data"] = this_slides_data
|
|
2132
2288
|
batch_results = self._format_output(predictions)
|
|
2133
|
-
|
|
2134
|
-
for frame_index, frame, ann in zip(batch, frames, anns):
|
|
2135
|
-
tracks_data = tracker.update(frame, ann, frame_index, tracks_data)
|
|
2289
|
+
|
|
2136
2290
|
inference_request.add_results(batch_results)
|
|
2137
2291
|
inference_request.done(len(batch_results))
|
|
2138
2292
|
logger.debug(f"Frames {batch[0]}-{batch[-1]} done.")
|
|
2139
2293
|
video_ann_json = None
|
|
2140
|
-
if tracker is not None:
|
|
2294
|
+
if inference_request.tracker is not None:
|
|
2141
2295
|
inference_request.set_stage("Postprocess...", 0, 1)
|
|
2142
|
-
video_ann_json = tracker.
|
|
2143
|
-
tracks_data, (video_info.frame_height, video_info.frame_width), n_frames
|
|
2144
|
-
).to_json()
|
|
2296
|
+
video_ann_json = inference_request.tracker.video_annotation.to_json()
|
|
2145
2297
|
inference_request.done()
|
|
2146
2298
|
inference_request.final_result = {"video_ann": video_ann_json}
|
|
2299
|
+
return video_ann_json
|
|
2300
|
+
|
|
2301
|
+
def _tracking_by_detection(self, api: Api, state: dict, inference_request: InferenceRequest):
|
|
2302
|
+
logger.debug("Inferring video_id...", extra={"state": state})
|
|
2303
|
+
inference_settings = self._get_inference_settings(state)
|
|
2304
|
+
logger.debug(f"Inference settings:", extra=inference_settings)
|
|
2305
|
+
batch_size = self._get_batch_size_from_state(state)
|
|
2306
|
+
video_id = get_value_for_keys(state, ["videoId", "video_id"], ignore_none=True)
|
|
2307
|
+
if video_id is None:
|
|
2308
|
+
raise ValueError("Video id is not provided")
|
|
2309
|
+
video_info = api.video.get_info_by_id(video_id)
|
|
2310
|
+
start_frame_index = get_value_for_keys(
|
|
2311
|
+
state, ["startFrameIndex", "start_frame_index", "start_frame"], ignore_none=True
|
|
2312
|
+
)
|
|
2313
|
+
if start_frame_index is None:
|
|
2314
|
+
start_frame_index = 0
|
|
2315
|
+
step = get_value_for_keys(state, ["stride", "step"], ignore_none=True)
|
|
2316
|
+
if step is None:
|
|
2317
|
+
step = 1
|
|
2318
|
+
end_frame_index = get_value_for_keys(
|
|
2319
|
+
state, ["endFrameIndex", "end_frame_index", "end_frame"], ignore_none=True
|
|
2320
|
+
)
|
|
2321
|
+
duration = state.get("duration", None)
|
|
2322
|
+
frames_count = get_value_for_keys(
|
|
2323
|
+
state, ["framesCount", "frames_count", "num_frames"], ignore_none=True
|
|
2324
|
+
)
|
|
2325
|
+
tracking = state.get("tracker", None)
|
|
2326
|
+
direction = state.get("direction", "forward")
|
|
2327
|
+
direction = 1 if direction == "forward" else -1
|
|
2328
|
+
track_id = get_value_for_keys(state, ["trackId", "track_id"], ignore_none=True)
|
|
2329
|
+
|
|
2330
|
+
if frames_count is not None:
|
|
2331
|
+
n_frames = frames_count
|
|
2332
|
+
elif end_frame_index is not None:
|
|
2333
|
+
n_frames = end_frame_index - start_frame_index
|
|
2334
|
+
elif duration is not None:
|
|
2335
|
+
fps = video_info.frames_count / video_info.duration
|
|
2336
|
+
n_frames = int(duration * fps)
|
|
2337
|
+
else:
|
|
2338
|
+
n_frames = video_info.frames_count
|
|
2339
|
+
|
|
2340
|
+
inference_request.tracker = self._tracker_init(state.get("tracker", None), state.get("tracker_settings", {}))
|
|
2341
|
+
|
|
2342
|
+
logger.debug(
|
|
2343
|
+
f"Video info:",
|
|
2344
|
+
extra=dict(
|
|
2345
|
+
w=video_info.frame_width,
|
|
2346
|
+
h=video_info.frame_height,
|
|
2347
|
+
start_frame_index=start_frame_index,
|
|
2348
|
+
n_frames=n_frames,
|
|
2349
|
+
),
|
|
2350
|
+
)
|
|
2351
|
+
|
|
2352
|
+
# start downloading video in background
|
|
2353
|
+
self.cache.run_cache_task_manually(api, None, video_id=video_id)
|
|
2354
|
+
|
|
2355
|
+
progress_total = (n_frames + step - 1) // step
|
|
2356
|
+
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, progress_total)
|
|
2357
|
+
|
|
2358
|
+
_upload_f = partial(
|
|
2359
|
+
self.upload_predictions_to_video,
|
|
2360
|
+
api=api,
|
|
2361
|
+
video_info=video_info,
|
|
2362
|
+
track_id=track_id,
|
|
2363
|
+
context=inference_request.context,
|
|
2364
|
+
progress_cb=inference_request.done,
|
|
2365
|
+
inference_request=inference_request,
|
|
2366
|
+
)
|
|
2367
|
+
|
|
2368
|
+
_range = (start_frame_index, start_frame_index + direction * n_frames)
|
|
2369
|
+
if _range[0] > _range[1]:
|
|
2370
|
+
_range = (_range[1], _range[0])
|
|
2371
|
+
|
|
2372
|
+
def _notify_f(predictions: List[Prediction]):
|
|
2373
|
+
logger.debug(
|
|
2374
|
+
"Notifying tracking progress...",
|
|
2375
|
+
extra={
|
|
2376
|
+
"track_id": track_id,
|
|
2377
|
+
"range": _range,
|
|
2378
|
+
"current": inference_request.progress.current,
|
|
2379
|
+
"total": inference_request.progress.total,
|
|
2380
|
+
},
|
|
2381
|
+
)
|
|
2382
|
+
stopped = self.api.video.notify_progress(
|
|
2383
|
+
track_id=track_id,
|
|
2384
|
+
video_id=video_info.id,
|
|
2385
|
+
frame_start=_range[0],
|
|
2386
|
+
frame_end=_range[1],
|
|
2387
|
+
current=inference_request.progress.current,
|
|
2388
|
+
total=inference_request.progress.total,
|
|
2389
|
+
)
|
|
2390
|
+
if stopped:
|
|
2391
|
+
inference_request.stop()
|
|
2392
|
+
logger.info("Tracking has been stopped by user", extra={"track_id": track_id})
|
|
2393
|
+
|
|
2394
|
+
def _exception_handler(e: Exception):
|
|
2395
|
+
self.api.video.notify_tracking_error(
|
|
2396
|
+
track_id=track_id,
|
|
2397
|
+
error=str(type(e)),
|
|
2398
|
+
message=str(e),
|
|
2399
|
+
)
|
|
2400
|
+
raise e
|
|
2401
|
+
|
|
2402
|
+
with Uploader(
|
|
2403
|
+
upload_f=_upload_f,
|
|
2404
|
+
notify_f=_notify_f,
|
|
2405
|
+
exception_handler=_exception_handler,
|
|
2406
|
+
logger=logger,
|
|
2407
|
+
) as uploader:
|
|
2408
|
+
for batch in batched(
|
|
2409
|
+
range(
|
|
2410
|
+
start_frame_index, start_frame_index + direction * n_frames, direction * step
|
|
2411
|
+
),
|
|
2412
|
+
batch_size,
|
|
2413
|
+
):
|
|
2414
|
+
if inference_request.is_stopped():
|
|
2415
|
+
logger.debug(
|
|
2416
|
+
f"Cancelling inference video...",
|
|
2417
|
+
extra={"inference_request_uuid": inference_request.uuid},
|
|
2418
|
+
)
|
|
2419
|
+
break
|
|
2420
|
+
logger.debug(
|
|
2421
|
+
f"Inferring frames {batch[0]}-{batch[-1]}:",
|
|
2422
|
+
)
|
|
2423
|
+
frames = self.cache.download_frames(
|
|
2424
|
+
api, video_info.id, batch, redownload_video=True
|
|
2425
|
+
)
|
|
2426
|
+
anns, slides_data = self._inference_auto(
|
|
2427
|
+
source=frames,
|
|
2428
|
+
settings=inference_settings,
|
|
2429
|
+
)
|
|
2430
|
+
|
|
2431
|
+
if inference_request.tracker is not None:
|
|
2432
|
+
anns = self._apply_tracker_to_anns(frames, anns, inference_request.tracker)
|
|
2433
|
+
|
|
2434
|
+
predictions = [
|
|
2435
|
+
Prediction(
|
|
2436
|
+
ann,
|
|
2437
|
+
model_meta=self.model_meta,
|
|
2438
|
+
frame_index=frame_index,
|
|
2439
|
+
video_id=video_info.id,
|
|
2440
|
+
dataset_id=video_info.dataset_id,
|
|
2441
|
+
project_id=video_info.project_id,
|
|
2442
|
+
)
|
|
2443
|
+
for ann, frame_index in zip(anns, batch)
|
|
2444
|
+
]
|
|
2445
|
+
for pred, this_slides_data in zip(predictions, slides_data):
|
|
2446
|
+
pred.extra_data["slides_data"] = this_slides_data
|
|
2447
|
+
uploader.put(predictions)
|
|
2448
|
+
video_ann_json = None
|
|
2449
|
+
if inference_request.tracker is not None:
|
|
2450
|
+
inference_request.set_stage("Postprocess...", 0, 1)
|
|
2451
|
+
video_ann_json = inference_request.tracker.video_annotation.to_json()
|
|
2452
|
+
inference_request.done()
|
|
2453
|
+
inference_request.final_result = {"video_ann": video_ann_json}
|
|
2454
|
+
return video_ann_json
|
|
2455
|
+
|
|
2147
2456
|
|
|
2148
2457
|
def _inference_project_id(self, api: Api, state: dict, inference_request: InferenceRequest):
|
|
2149
2458
|
"""Inference project images.
|
|
@@ -2161,10 +2470,12 @@ class Inference:
|
|
|
2161
2470
|
project_info = api.project.get_info_by_id(project_id)
|
|
2162
2471
|
if project_info.type != str(ProjectType.IMAGES):
|
|
2163
2472
|
raise ValueError("Only images projects are supported.")
|
|
2473
|
+
|
|
2474
|
+
model_prediction_suffix = state.get("model_prediction_suffix", None)
|
|
2164
2475
|
upload_mode = state.get("upload_mode", None)
|
|
2165
2476
|
iou_merge_threshold = inference_settings.get("existing_objects_iou_thresh", None)
|
|
2166
2477
|
if upload_mode == "iou_merge" and iou_merge_threshold is None:
|
|
2167
|
-
iou_merge_threshold =
|
|
2478
|
+
iou_merge_threshold = self.DEFAULT_IOU_MERGE_THRESHOLD
|
|
2168
2479
|
cache_project_on_model = state.get("cache_project_on_model", False)
|
|
2169
2480
|
|
|
2170
2481
|
project_info = api.project.get_info_by_id(project_id)
|
|
@@ -2235,6 +2546,7 @@ class Inference:
|
|
|
2235
2546
|
progress_cb=inference_request.done,
|
|
2236
2547
|
iou_merge_threshold=iou_merge_threshold,
|
|
2237
2548
|
inference_request=inference_request,
|
|
2549
|
+
model_prediction_suffix=model_prediction_suffix,
|
|
2238
2550
|
)
|
|
2239
2551
|
|
|
2240
2552
|
_add_results_to_request = partial(
|
|
@@ -2260,7 +2572,7 @@ class Inference:
|
|
|
2260
2572
|
return
|
|
2261
2573
|
if uploader.has_exception():
|
|
2262
2574
|
exception = uploader.exception
|
|
2263
|
-
raise
|
|
2575
|
+
raise exception
|
|
2264
2576
|
if cache_project_on_model:
|
|
2265
2577
|
images_paths, _ = zip(
|
|
2266
2578
|
*read_from_cached_project(
|
|
@@ -2389,7 +2701,7 @@ class Inference:
|
|
|
2389
2701
|
return
|
|
2390
2702
|
if uploader.has_exception():
|
|
2391
2703
|
exception = uploader.exception
|
|
2392
|
-
raise
|
|
2704
|
+
raise exception
|
|
2393
2705
|
if i == num_warmup:
|
|
2394
2706
|
inference_request.set_stage(InferenceRequest.Stage.INFERENCE, 0, num_iterations)
|
|
2395
2707
|
|
|
@@ -2454,6 +2766,7 @@ class Inference:
|
|
|
2454
2766
|
# raise DialogWindowError(title="Call undeployed model.", description=msg)
|
|
2455
2767
|
raise RuntimeError(msg)
|
|
2456
2768
|
return func(*args, **kwargs)
|
|
2769
|
+
|
|
2457
2770
|
return wrapper
|
|
2458
2771
|
|
|
2459
2772
|
def _freeze_model(self):
|
|
@@ -2498,7 +2811,6 @@ class Inference:
|
|
|
2498
2811
|
timer.daemon = True
|
|
2499
2812
|
timer.start()
|
|
2500
2813
|
self._freeze_timer = timer
|
|
2501
|
-
logger.debug("Model will be frozen in %s seconds due to inactivity.", self._inactivity_timeout)
|
|
2502
2814
|
|
|
2503
2815
|
def _set_served_callback(self):
|
|
2504
2816
|
self._model_served = True
|
|
@@ -2605,11 +2917,16 @@ class Inference:
|
|
|
2605
2917
|
progress_cb=None,
|
|
2606
2918
|
iou_merge_threshold: float = None,
|
|
2607
2919
|
inference_request: InferenceRequest = None,
|
|
2920
|
+
model_prediction_suffix: str = None,
|
|
2608
2921
|
):
|
|
2609
2922
|
ds_predictions: Dict[int, List[Prediction]] = defaultdict(list)
|
|
2610
2923
|
for prediction in predictions:
|
|
2611
2924
|
ds_predictions[prediction.dataset_id].append(prediction)
|
|
2612
2925
|
|
|
2926
|
+
def update_labeling_status(ann: Annotation) -> Annotation:
|
|
2927
|
+
for label in ann.labels:
|
|
2928
|
+
label.status = LabelingStatus.AUTO
|
|
2929
|
+
|
|
2613
2930
|
def _new_name(image_info: ImageInfo):
|
|
2614
2931
|
name = Path(image_info.name)
|
|
2615
2932
|
stem = name.stem
|
|
@@ -2642,10 +2959,10 @@ class Inference:
|
|
|
2642
2959
|
context.setdefault("created_dataset", {})[src_dataset_id] = created_dataset.id
|
|
2643
2960
|
return created_dataset.id
|
|
2644
2961
|
|
|
2645
|
-
created_names = []
|
|
2646
2962
|
if context is None:
|
|
2647
2963
|
context = {}
|
|
2648
2964
|
for dataset_id, preds in ds_predictions.items():
|
|
2965
|
+
created_names = set()
|
|
2649
2966
|
if dst_project_id is not None:
|
|
2650
2967
|
# upload to the destination project
|
|
2651
2968
|
dst_dataset_id = _get_or_create_dataset(
|
|
@@ -2666,7 +2983,9 @@ class Inference:
|
|
|
2666
2983
|
meta_changed = False
|
|
2667
2984
|
for pred in preds:
|
|
2668
2985
|
ann = pred.annotation
|
|
2669
|
-
project_meta, ann, meta_changed_ = update_meta_and_ann(
|
|
2986
|
+
project_meta, ann, meta_changed_ = update_meta_and_ann(
|
|
2987
|
+
project_meta, ann, model_prediction_suffix
|
|
2988
|
+
)
|
|
2670
2989
|
meta_changed = meta_changed or meta_changed_
|
|
2671
2990
|
pred.annotation = ann
|
|
2672
2991
|
prediction.model_meta = project_meta
|
|
@@ -2683,8 +3002,15 @@ class Inference:
|
|
|
2683
3002
|
iou=iou_merge_threshold,
|
|
2684
3003
|
meta=project_meta,
|
|
2685
3004
|
)
|
|
3005
|
+
|
|
3006
|
+
# Update labeling status of new predictions before upload
|
|
3007
|
+
anns_with_nn_flags = []
|
|
2686
3008
|
for pred, ann in zip(preds, anns):
|
|
3009
|
+
update_labeling_status(ann)
|
|
2687
3010
|
pred.annotation = ann
|
|
3011
|
+
anns_with_nn_flags.append(ann)
|
|
3012
|
+
|
|
3013
|
+
anns = anns_with_nn_flags
|
|
2688
3014
|
|
|
2689
3015
|
context.setdefault("image_info", {})
|
|
2690
3016
|
missing = [
|
|
@@ -2712,7 +3038,7 @@ class Inference:
|
|
|
2712
3038
|
with_annotations=False,
|
|
2713
3039
|
save_source_date=False,
|
|
2714
3040
|
)
|
|
2715
|
-
created_names.
|
|
3041
|
+
created_names.update([image_info.name for image_info in dst_image_infos])
|
|
2716
3042
|
api.annotation.upload_anns([image_info.id for image_info in dst_image_infos], anns)
|
|
2717
3043
|
else:
|
|
2718
3044
|
# upload to the source dataset
|
|
@@ -2730,7 +3056,9 @@ class Inference:
|
|
|
2730
3056
|
meta_changed = False
|
|
2731
3057
|
for pred in preds:
|
|
2732
3058
|
ann = pred.annotation
|
|
2733
|
-
project_meta, ann, meta_changed_ = update_meta_and_ann(
|
|
3059
|
+
project_meta, ann, meta_changed_ = update_meta_and_ann(
|
|
3060
|
+
project_meta, ann, model_prediction_suffix
|
|
3061
|
+
)
|
|
2734
3062
|
meta_changed = meta_changed or meta_changed_
|
|
2735
3063
|
pred.annotation = ann
|
|
2736
3064
|
prediction.model_meta = project_meta
|
|
@@ -2747,7 +3075,10 @@ class Inference:
|
|
|
2747
3075
|
iou=iou_merge_threshold,
|
|
2748
3076
|
meta=project_meta,
|
|
2749
3077
|
)
|
|
3078
|
+
|
|
3079
|
+
# Update labeling status of predicted labels before optional merge
|
|
2750
3080
|
for pred, ann in zip(preds, anns):
|
|
3081
|
+
update_labeling_status(ann)
|
|
2751
3082
|
pred.annotation = ann
|
|
2752
3083
|
|
|
2753
3084
|
if upload_mode in ["iou_merge", "append"]:
|
|
@@ -2789,6 +3120,83 @@ class Inference:
|
|
|
2789
3120
|
inference_request.add_results(results)
|
|
2790
3121
|
inference_request.done(len(results))
|
|
2791
3122
|
|
|
3123
|
+
def upload_predictions_to_video(
|
|
3124
|
+
self,
|
|
3125
|
+
predictions: List[Prediction],
|
|
3126
|
+
api: Api,
|
|
3127
|
+
video_info: VideoInfo,
|
|
3128
|
+
track_id: str,
|
|
3129
|
+
context: Dict,
|
|
3130
|
+
progress_cb=None,
|
|
3131
|
+
inference_request: InferenceRequest = None,
|
|
3132
|
+
):
|
|
3133
|
+
key_id_map = KeyIdMap()
|
|
3134
|
+
project_meta = context.get("project_meta", None)
|
|
3135
|
+
if project_meta is None:
|
|
3136
|
+
project_meta = ProjectMeta.from_json(api.project.get_meta(video_info.project_id))
|
|
3137
|
+
context["project_meta"] = project_meta
|
|
3138
|
+
meta_changed = False
|
|
3139
|
+
for prediction in predictions:
|
|
3140
|
+
project_meta, ann, meta_changed_ = update_meta_and_ann(
|
|
3141
|
+
project_meta, prediction.annotation, None
|
|
3142
|
+
)
|
|
3143
|
+
prediction.annotation = ann
|
|
3144
|
+
meta_changed = meta_changed or meta_changed_
|
|
3145
|
+
if meta_changed:
|
|
3146
|
+
project_meta = api.project.update_meta(video_info.project_id, project_meta)
|
|
3147
|
+
context["project_meta"] = project_meta
|
|
3148
|
+
|
|
3149
|
+
figure_data_by_object_id = defaultdict(list)
|
|
3150
|
+
|
|
3151
|
+
tracks_to_object_ids = context.setdefault("tracks_to_object_ids", {})
|
|
3152
|
+
new_tracks: Dict[int, VideoObject] = {}
|
|
3153
|
+
for prediction in predictions:
|
|
3154
|
+
annotation = prediction.annotation
|
|
3155
|
+
tracks = annotation.custom_data
|
|
3156
|
+
for track, label in zip(tracks, annotation.labels):
|
|
3157
|
+
if track not in tracks_to_object_ids and track not in new_tracks:
|
|
3158
|
+
video_object = VideoObject(obj_class=label.obj_class)
|
|
3159
|
+
new_tracks[track] = video_object
|
|
3160
|
+
if new_tracks:
|
|
3161
|
+
tracks, video_objects = zip(*new_tracks.items())
|
|
3162
|
+
added_object_ids = api.video.object.append_bulk(
|
|
3163
|
+
video_info.id, VideoObjectCollection(video_objects), key_id_map=key_id_map
|
|
3164
|
+
)
|
|
3165
|
+
for track, object_id in zip(tracks, added_object_ids):
|
|
3166
|
+
tracks_to_object_ids[track] = object_id
|
|
3167
|
+
for prediction in predictions:
|
|
3168
|
+
annotation = prediction.annotation
|
|
3169
|
+
tracks = annotation.custom_data
|
|
3170
|
+
for track, label in zip(tracks, annotation.labels):
|
|
3171
|
+
object_id = tracks_to_object_ids[track]
|
|
3172
|
+
figure_data_by_object_id[object_id].append(
|
|
3173
|
+
{
|
|
3174
|
+
ApiField.OBJECT_ID: object_id,
|
|
3175
|
+
ApiField.GEOMETRY_TYPE: label.geometry.geometry_name(),
|
|
3176
|
+
ApiField.GEOMETRY: label.geometry.to_json(),
|
|
3177
|
+
ApiField.META: {ApiField.FRAME: prediction.frame_index},
|
|
3178
|
+
ApiField.TRACK_ID: track_id,
|
|
3179
|
+
}
|
|
3180
|
+
)
|
|
3181
|
+
|
|
3182
|
+
for object_id, figures_data in figure_data_by_object_id.items():
|
|
3183
|
+
figures_keys = [uuid.uuid4() for _ in figures_data]
|
|
3184
|
+
api.video.figure._append_bulk(
|
|
3185
|
+
entity_id=video_info.id,
|
|
3186
|
+
figures_json=figures_data,
|
|
3187
|
+
figures_keys=figures_keys,
|
|
3188
|
+
key_id_map=key_id_map,
|
|
3189
|
+
)
|
|
3190
|
+
logger.debug(f"Added {len(figures_data)} geometries to object #{object_id}")
|
|
3191
|
+
if progress_cb:
|
|
3192
|
+
progress_cb(len(predictions))
|
|
3193
|
+
if inference_request is not None:
|
|
3194
|
+
results = self._format_output(predictions)
|
|
3195
|
+
for result in results:
|
|
3196
|
+
result["annotation"] = None
|
|
3197
|
+
result["data"] = None
|
|
3198
|
+
inference_request.add_results(results)
|
|
3199
|
+
|
|
2792
3200
|
def serve(self):
|
|
2793
3201
|
if not self._use_gui and not self._is_cli_deploy:
|
|
2794
3202
|
Progress("Deploying model ...", 1)
|
|
@@ -2812,12 +3220,12 @@ class Inference:
|
|
|
2812
3220
|
# Predict and shutdown
|
|
2813
3221
|
if self._args.mode == "predict":
|
|
2814
3222
|
if any(
|
|
2815
|
-
|
|
2816
|
-
|
|
2817
|
-
|
|
2818
|
-
|
|
2819
|
-
|
|
2820
|
-
|
|
3223
|
+
[
|
|
3224
|
+
self._args.input,
|
|
3225
|
+
self._args.project_id,
|
|
3226
|
+
self._args.dataset_id,
|
|
3227
|
+
self._args.image_id,
|
|
3228
|
+
]
|
|
2821
3229
|
):
|
|
2822
3230
|
self._parse_inference_settings_from_args()
|
|
2823
3231
|
self._inference_by_cli_deploy_args()
|
|
@@ -2898,6 +3306,11 @@ class Inference:
         def get_session_info(response: Response):
             return self.get_info()

+        @server.post("/get_tracking_settings")
+        @self._check_serve_before_call
+        def get_tracking_settings(response: Response):
+            return self.get_tracking_settings()
+
         @server.post("/get_custom_inference_settings")
         def get_custom_inference_settings():
             return {"settings": self.custom_inference_settings}
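
For reference, a hedged sketch of calling the new route from a client. `api.task.send_request` posts to a route registered by a running serving app; the task id below is a placeholder:

```python
import supervisely as sly

api = sly.Api.from_env()
# 12345 is a placeholder id of a running serving task
settings = api.task.send_request(12345, "get_tracking_settings", data={})
print(settings)
```
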
@@ -3181,6 +3594,22 @@ class Inference:
                 "inference_request_uuid": inference_request.uuid,
             }

+        @server.post("/tracking_by_detection")
+        def tracking_by_detection(response: Response, request: Request):
+            state = request.state.state
+            context = request.state.context
+            state.update(context)
+            if state.get("tracker") is None:
+                state["tracker"] = "botsort"
+
+            logger.debug("Received a request to 'tracking_by_detection'", extra={"state": state})
+            self.validate_inference_state(state)
+            api = self.api_from_request(request)
+            inference_request, future = self.inference_requests_manager.schedule_task(
+                self._tracking_by_detection, api, state
+            )
+            return {"message": "Track task started."}
+
         @server.post("/inference_project_id_async")
         def inference_project_id_async(response: Response, request: Request):
             state = request.state.state
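
The handler reads its parameters from the request `state` (merged with `context`) and falls back to the "botsort" tracker when none is given. A sketch of a client call; the `video_id` field is an assumption about what `validate_inference_state` expects, not confirmed by this diff:

```python
response = api.task.send_request(
    12345,                     # placeholder task id of the serving app
    "tracking_by_detection",
    data={
        "video_id": 678,       # assumption: the target entity is passed in state
        "tracker": "botsort",  # optional; the handler defaults to "botsort"
    },
)
print(response)  # {"message": "Track task started."}
```
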
@@ -3244,10 +3673,7 @@ class Inference:
             data = {**inference_request.to_json(), **log_extra}
             if inference_request.stage != InferenceRequest.Stage.INFERENCE:
                 data["progress"] = {"current": 0, "total": 1}
-            logger.debug(
-                f"Sending inference progress with uuid:",
-                extra=data,
-            )
+            logger.debug(f"Sending inference progress with uuid:", extra=data)
             return data

         @server.post(f"/pop_inference_results")
@@ -3671,6 +4097,7 @@ class Inference:

     def _parse_inference_settings_from_args(self):
         logger.debug("Parsing inference settings from args")
+
         def parse_value(value: str):
             if value.lower() in ("true", "false"):
                 return value.lower() == "true"
@@ -3797,8 +4224,7 @@ class Inference:
         try:
             # Read data from checkpoint
             logger.debug(f"Reading data from checkpoint: {checkpoint_path}")
-
-            checkpoint = torch.load(checkpoint_path)
+            checkpoint = torch_load_safe(checkpoint_path)
             model_info = checkpoint["model_info"]
             model_files = self._extract_model_files_from_checkpoint(checkpoint_path)
             model_meta = os.path.join(self.model_dir, "model_meta.json")
@@ -4028,6 +4454,7 @@ class Inference:
         draw: bool = False,
     ):
         logger.info(f"Predicting Local Data: {input_path}")
+
         def postprocess_image(image_path: str, ann: Annotation, pred_dir: str = None):
             image_name = sly_fs.get_file_name_with_ext(image_path)
             if pred_dir is not None:
@@ -4103,6 +4530,20 @@ class Inference:
             self._args.draw,
         )

+    def _apply_tracker_to_anns(self, frames: List[np.ndarray], anns: List[Annotation], tracker):
+        updated_anns = []
+        for frame, ann in zip(frames, anns):
+            matches = tracker.update(frame, ann)
+            track_ids = [match["track_id"] for match in matches]
+            tracked_labels = [match["label"] for match in matches]
+
+            filtered_annotation = ann.clone(
+                labels=tracked_labels,
+                custom_data=track_ids
+            )
+            updated_anns.append(filtered_annotation)
+        return updated_anns
+
     def _add_workflow_input(self, model_source: str, model_files: dict, model_info: dict):
         if model_source == ModelSource.PRETRAINED:
             checkpoint_url = model_info["meta"]["model_files"]["checkpoint"]
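
`_apply_tracker_to_anns` implies a tracker contract: `update(frame, ann)` returns a list of match dicts with "track_id" and "label" keys. A stub satisfying that contract, for illustration only (a real implementation would be e.g. BoT-SORT):

```python
class DummyTracker:
    """Assigns a fresh track id to every label; stands in for a real tracker."""

    def __init__(self):
        self._next_id = 0

    def update(self, frame, ann):
        matches = []
        for label in ann.labels:
            matches.append({"track_id": self._next_id, "label": label})
            self._next_id += 1
        return matches
```
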
@@ -4136,13 +4577,14 @@ class Inference:

             task_id = experiment_info.task_id
             self.gui.model_source_tabs.set_active_tab(ModelSource.CUSTOM)
-            self.gui.experiment_selector.
+            self.gui.experiment_selector.set_selected_row_by_task_id(task_id)

             best_ckpt = experiment_info.best_checkpoint
             if best_ckpt:
-                row = self.gui.experiment_selector.
+                row = self.gui.experiment_selector.get_selected_row_by_task_id(task_id)
                 if row is not None:
                     row.set_selected_checkpoint_by_name(best_ckpt)
+
         except Exception as e:
             logger.warning(f"Failed to set checkpoint from experiment info: {repr(e)}")

@@ -4151,61 +4593,78 @@ class Inference:
             return
         self.gui.model_source_tabs.set_active_tab(ModelSource.PRETRAINED)

-    def
-
-
-
-
-
-
+    def export_onnx(self, deploy_params: dict):
+        raise NotImplementedError("Have to be implemented in child class after inheritance")
+
+    def export_tensorrt(self, deploy_params: dict):
+        raise NotImplementedError("Have to be implemented in child class after inheritance")
+
+
+def _filter_duplicated_predictions_from_ann_cpu(
+    gt_ann: Annotation, pred_ann: Annotation, iou_threshold: float
 ):
     """
-    Filter out
-
-    This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
-    - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
-    - Gets ProjectMeta object if not provided
-    - Downloads GT annotations for the specified image IDs
-    - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
+    Filter out predicted labels whose bboxes have IoU > iou_threshold with any GT label.
+    Uses Shapely for geometric operations.

-    :
-
-
-
-    :param dataset_id: ID of the dataset containing the images
-    :type dataset_id: int
-    :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
-    :type gt_image_ids: List[int]
-    :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
-        ground truth box of the same class will be removed. None if no filtering is needed
-    :type iou: Optional[float]
-    :param meta: ProjectMeta object
-    :type meta: Optional[ProjectMeta]
-    :return: List of Annotation objects containing filtered predictions
-    :rtype: List[Annotation]
+    Args:
+        pred_ann: Predicted annotation object
+        gt_ann: Ground truth annotation object
+        iou_threshold: IoU threshold for filtering

-
-
-    - Requires PyTorch and torchvision for IoU calculations
-    - This method is useful for identifying new objects that aren't already annotated in the ground truth
+    Returns:
+        New annotation with filtered labels
     """
-    if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not iou_threshold:
+        return pred_ann
+
+    from shapely.geometry import box
+
+    def calculate_iou(geom1: Geometry, geom2: Geometry):
+        """Calculate IoU between two geometries using Shapely."""
+        bbox1 = geom1.to_bbox()
+        bbox2 = geom2.to_bbox()
+
+        box1 = box(bbox1.left, bbox1.top, bbox1.right, bbox1.bottom)
+        box2 = box(bbox2.left, bbox2.top, bbox2.right, bbox2.bottom)
+
+        intersection = box1.intersection(box2).area
+        union = box1.union(box2).area
+
+        return intersection / union if union > 0 else 0.0
+
+    new_labels = []
+    pred_cls_bboxes = defaultdict(list)
+    for label in pred_ann.labels:
+        name_shape = (label.obj_class.name, label.geometry.name())
+        pred_cls_bboxes[name_shape].append(label)
+
+    gt_cls_bboxes = defaultdict(list)
+    for label in gt_ann.labels:
+        name_shape = (label.obj_class.name, label.geometry.name())
+        if name_shape not in pred_cls_bboxes:
+            continue
+        gt_cls_bboxes[name_shape].append(label)
+
+    for name_shape, pred in pred_cls_bboxes.items():
+        gt = gt_cls_bboxes[name_shape]
+        if len(gt) == 0:
+            new_labels.extend(pred)
+            continue
+
+        for pred_label in pred:
+            # Check if this prediction has IoU < threshold with ALL GT boxes
+            keep = True
+            for gt_label in gt:
+                iou = calculate_iou(pred_label.geometry, gt_label.geometry)
+                if iou >= iou_threshold:
+                    keep = False
+                    break
+
+            if keep:
+                new_labels.append(pred_label)
+
+    return pred_ann.clone(labels=new_labels)


 def _filter_duplicated_predictions_from_ann(
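
A worked example of the Shapely IoU used by the CPU fallback: two 100x100 boxes offset by 50 px share a 50x100 intersection (area 5000) and a union of 15000, so IoU = 1/3:

```python
from shapely.geometry import box

b1 = box(0, 0, 100, 100)
b2 = box(50, 0, 150, 100)
iou = b1.intersection(b2).area / b1.union(b2).area
print(round(iou, 4))  # 0.3333; such a pair is kept under any threshold above 1/3
```
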
@@ -4236,13 +4695,15 @@ def _filter_duplicated_predictions_from_ann(
     - Predictions with classes not present in ground truth will be kept
     - Requires PyTorch and torchvision for IoU calculations
     """
+    if not iou_threshold:
+        return pred_ann

     try:
         import torch
         from torchvision.ops import box_iou

     except ImportError:
-
+        return _filter_duplicated_predictions_from_ann_cpu(gt_ann, pred_ann, iou_threshold)

     def _to_tensor(geom):
         return torch.tensor([geom.left, geom.top, geom.right, geom.bottom]).float()
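
When torch and torchvision are installed, the function keeps the vectorized path. For reference, `torchvision.ops.box_iou` takes xyxy tensors of shape [N, 4] and [M, 4] and returns an [N, M] IoU matrix:

```python
import torch
from torchvision.ops import box_iou

pred = torch.tensor([[0.0, 0.0, 100.0, 100.0]])
gt = torch.tensor([[50.0, 0.0, 150.0, 100.0]])
print(box_iou(pred, gt))  # tensor([[0.3333]]), same result as the Shapely fallback
```
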
@@ -4250,16 +4711,18 @@ def _filter_duplicated_predictions_from_ann(
     new_labels = []
     pred_cls_bboxes = defaultdict(list)
     for label in pred_ann.labels:
-
+        name_shape = (label.obj_class.name, label.geometry.name())
+        pred_cls_bboxes[name_shape].append(label)

     gt_cls_bboxes = defaultdict(list)
     for label in gt_ann.labels:
-
+        name_shape = (label.obj_class.name, label.geometry.name())
+        if name_shape not in pred_cls_bboxes:
             continue
-        gt_cls_bboxes[
+        gt_cls_bboxes[name_shape].append(label)

-    for
-        gt = gt_cls_bboxes[
+    for name_shape, pred in pred_cls_bboxes.items():
+        gt = gt_cls_bboxes[name_shape]
         if len(gt) == 0:
             new_labels.extend(pred)
             continue
@@ -4273,6 +4736,63 @@ def _filter_duplicated_predictions_from_ann(
     return pred_ann.clone(labels=new_labels)


+def _exclude_duplicated_predictions(
+    api: Api,
+    pred_anns: List[Annotation],
+    dataset_id: int,
+    gt_image_ids: List[int],
+    iou: float = None,
+    meta: Optional[ProjectMeta] = None,
+):
+    """
+    Filter out predictions that significantly overlap with ground truth (GT) objects.
+
+    This is a wrapper around the `_filter_duplicated_predictions_from_ann` method that does the following:
+    - Checks inference settings for the IoU threshold (`existing_objects_iou_thresh`)
+    - Gets ProjectMeta object if not provided
+    - Downloads GT annotations for the specified image IDs
+    - Filters out predictions that have an IoU greater than or equal to the specified threshold with any GT object
+
+    :param api: Supervisely API object
+    :type api: Api
+    :param pred_anns: List of Annotation objects containing predictions
+    :type pred_anns: List[Annotation]
+    :param dataset_id: ID of the dataset containing the images
+    :type dataset_id: int
+    :param gt_image_ids: List of image IDs to filter predictions. All images should belong to the same dataset
+    :type gt_image_ids: List[int]
+    :param iou: IoU threshold (0.0-1.0). Predictions with IoU >= threshold with any
+        ground truth box of the same class will be removed. None if no filtering is needed
+    :type iou: Optional[float]
+    :param meta: ProjectMeta object
+    :type meta: Optional[ProjectMeta]
+    :return: List of Annotation objects containing filtered predictions
+    :rtype: List[Annotation]
+
+    Notes:
+    ------
+    - Requires PyTorch and torchvision for IoU calculations
+    - This method is useful for identifying new objects that aren't already annotated in the ground truth
+    """
+    if isinstance(iou, float) and 0 < iou <= 1:
+        if meta is None:
+            ds = api.dataset.get_info_by_id(dataset_id)
+            meta = ProjectMeta.from_json(api.project.get_meta(ds.project_id))
+        gt_anns = api.annotation.download_json_batch(dataset_id, gt_image_ids)
+        gt_anns = [Annotation.from_json(ann, meta) for ann in gt_anns]
+        for i in range(0, len(pred_anns)):
+            before = len(pred_anns[i].labels)
+            with Timer() as timer:
+                pred_anns[i] = _filter_duplicated_predictions_from_ann(
+                    gt_anns[i], pred_anns[i], iou
+                )
+            after = len(pred_anns[i].labels)
+            logger.debug(
+                f"{[i]}: applied NMS with IoU={iou}. Before: {before}, After: {after}. Time: {timer.get_time():.3f}ms"
+            )
+        return pred_anns
+
+
 def _get_log_extra_for_inference_request(
     inference_request_uuid, inference_request: Union[InferenceRequest, dict]
 ):
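
Hedged usage sketch of the wrapper; the IDs and the prepared `pred_anns` list are placeholders. Note it only filters when `iou` is a float in (0, 1]:

```python
filtered = _exclude_duplicated_predictions(
    api=api,
    pred_anns=pred_anns,        # List[Annotation], one per image
    dataset_id=608,             # placeholder dataset id
    gt_image_ids=[1111, 1112],  # same order as pred_anns
    iou=0.9,                    # predictions with IoU >= 0.9 vs GT are dropped
)
```
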
@@ -4299,8 +4819,8 @@ def _get_log_extra_for_inference_request(
         "has_result": inference_request.final_result is not None,
         "pending_results": inference_request.pending_num(),
         "exception": inference_request.exception_json(),
-        "result": inference_request._final_result,
         "preparing_progress": progress,
+        "result": inference_request.final_result is not None,  # for backward compatibility
     }
     return log_extra

@@ -4380,7 +4900,7 @@ def get_gpu_count():
         gpu_count = len(re.findall(r"GPU \d+:", nvidia_smi_output))
         return gpu_count
     except (subprocess.CalledProcessError, FileNotFoundError) as exc:
-        logger.
+        logger.warning("Calling nvidia-smi caused a error: {exc}. Assume there is no any GPU.")
         return 0


@@ -4426,7 +4946,7 @@ def _fix_classes_names(meta: ProjectMeta, ann: Annotation):
     return meta, ann, replaced_classes_in_meta, list(replaced_classes_in_ann)


-def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
+def update_meta_and_ann(meta: ProjectMeta, ann: Annotation, model_prediction_suffix: str = None):
     """Update project meta and annotation to match each other
     If obj class or tag meta from annotation conflicts with project meta
     add suffix to obj class or tag meta.
@@ -4434,8 +4954,13 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
     """
     obj_classes_suffixes = ["_nn"]
     tag_meta_suffixes = ["_nn"]
-
-
+    if model_prediction_suffix is not None:
+        obj_classes_suffixes = [model_prediction_suffix]
+        tag_meta_suffixes = [model_prediction_suffix]
+        logger.debug(
+            f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
+        )
+    logger.debug("source meta", extra={"meta": meta.to_json()})
     meta_changed = False

     meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
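
A sketch of the new parameter's effect: with `model_prediction_suffix="_yolo"` a conflicting predicted class is re-added as e.g. `car_yolo` instead of the default `car_nn`. The suffix value and `project_id` are illustrative:

```python
meta, ann, meta_changed = update_meta_and_ann(meta, ann, model_prediction_suffix="_yolo")
if meta_changed:
    api.project.update_meta(project_id, meta)  # persist the extended meta
```
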
@@ -4446,91 +4971,289 @@ def update_meta_and_ann(meta: ProjectMeta, ann: Annotation):
         extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
     )

-
+    updated_labels = []
+    any_label_updated = False
+    for label in ann.labels:
+        original_obj_class_name = label.obj_class.name
+        suffix_found = False
+        for suffix in ["", *obj_classes_suffixes]:
+            label_obj_class = label.obj_class
+            label_obj_class_name = label_obj_class.name + suffix
+            if suffix:
+                label_obj_class = label_obj_class.clone(name=label_obj_class_name)
+                label = label.clone(obj_class=label_obj_class)
+                any_label_updated = True
+            meta_obj_class = meta.get_obj_class(label_obj_class_name)
+            if meta_obj_class is None:
+                # if obj class is not in meta, add it with suffix
+                meta = meta.add_obj_class(label_obj_class)
+                updated_labels.append(label)
+                meta_changed = True
+                suffix_found = True
+                break
+            elif meta_obj_class.geometry_type.geometry_name() == label.geometry.geometry_name():
+                # if label geometry is the same as in meta, use meta obj class
+                label = label.clone(obj_class=meta_obj_class)
+                updated_labels.append(label)
+                suffix_found = True
+                any_label_updated = True
+                break
+            elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
+                # if meta obj class is AnyGeometry, use it in label
+                label = label.clone(obj_class=meta_obj_class)
+                updated_labels.append(label)
+                suffix_found = True
+                any_label_updated = True
+                break
+        if not suffix_found:
+            # if no suffix found, raise error
+            raise ValueError(
+                f"Can't add obj class {original_obj_class_name} to project meta. "
+                "Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
+                "Please check if model geometry type is compatible with existing obj classes."
+            )
+    if any_label_updated:
+        ann = ann.clone(labels=updated_labels)
+
+    # check if tag metas are in project meta
+    # if not, add them with suffix
+    ann_tag_metas = {}
     for label in ann.labels:
-        ann_obj_classes[label.obj_class.name] = label.obj_class
         for tag in label.tags:
             ann_tag_metas[tag.meta.name] = tag.meta
     for tag in ann.img_tags:
         ann_tag_metas[tag.meta.name] = tag.meta

-
-
-
-
-
-        if meta.get_obj_class(ann_obj_class.name) is None:
-            meta = meta.add_obj_class(ann_obj_class)
+    changed_tag_metas = {}
+    for ann_tag_meta in ann_tag_metas.values():
+        meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
+        if meta_tag_meta is None:
+            meta = meta.add_tag_meta(ann_tag_meta)
             meta_changed = True
-        elif (
-
-
-
-
-
-
-
-            new_obj_class = ann_obj_class.clone(name=new_obj_class_name)
-            meta = meta.add_obj_class(new_obj_class)
+        elif not meta_tag_meta.is_compatible(ann_tag_meta):
+            suffix_found = False
+            for suffix in tag_meta_suffixes:
+                new_tag_meta_name = ann_tag_meta.name + suffix
+                meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
+                if meta_tag_meta is None:
+                    new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
+                    meta = meta.add_tag_meta(new_tag_meta)
+                    changed_tag_metas[ann_tag_meta.name] = new_tag_meta
                     meta_changed = True
-
-            found = True
+                    suffix_found = True
                     break
-        if
-
-
+                if meta_tag_meta.is_compatible(ann_tag_meta):
+                    changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
+                    suffix_found = True
                     break
-        if not
-            raise ValueError(f"Can't add
+            if not suffix_found:
+                raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")
+
+    if changed_tag_metas:
+        labels = []
+        any_label_updated = False
+        for label in ann.labels:
+            any_tag_updated = False
+            label_tags = []
+            for tag in label.tags:
+                if tag.meta.name in changed_tag_metas:
+                    label_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+                    any_tag_updated = True
+                else:
+                    label_tags.append(tag)
+            if any_tag_updated:
+                label = label.clone(tags=TagCollection(label_tags))
+                any_label_updated = True
+            labels.append(label)
+        img_tags = []
+        any_tag_updated = False
+        for tag in ann.img_tags:
+            if tag.meta.name in changed_tag_metas:
+                img_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+                any_tag_updated = True
+            else:
+                img_tags.append(tag)
+        if any_tag_updated or any_label_updated:
+            if any_tag_updated:
+                img_tags = TagCollection(img_tags)
+            else:
+                img_tags = None
+            if not any_label_updated:
+                labels = None
+            ann = ann.clone(img_tags=img_tags)
+    return meta, ann, meta_changed
+
+
+def update_meta_and_ann_for_video_annotation(
+    meta: ProjectMeta, ann: VideoAnnotation, model_prediction_suffix: str = None
+):
+    """Update project meta and annotation to match each other
+    If obj class or tag meta from annotation conflicts with project meta
+    add suffix to obj class or tag meta.
+    Return tuple of updated project meta, annotation and boolean flag if meta was changed.
+    """
+    obj_classes_suffixes = ["_nn"]
+    tag_meta_suffixes = ["_nn"]
+    if model_prediction_suffix is not None:
+        obj_classes_suffixes = [model_prediction_suffix]
+        tag_meta_suffixes = [model_prediction_suffix]
+        logger.debug(
+            f"Using custom suffixes for obj classes and tag metas: {obj_classes_suffixes}, {tag_meta_suffixes}"
+        )
+    logger.debug("source meta", extra={"meta": meta.to_json()})
+    meta_changed = False
+
+    # meta, ann, replaced_classes_in_meta, replaced_classes_in_ann = _fix_classes_names(meta, ann)
+    # if replaced_classes_in_meta:
+    #     meta_changed = True
+    #     logger.warning(
+    #         "Some classes names were fixed in project meta",
+    #         extra={"replaced_classes": {old: new for old, new in replaced_classes_in_meta}},
+    #     )
+
+    new_objects: List[VideoObject] = []
+    new_figures: List[VideoFigure] = []
+    any_object_updated = False
+    for video_object in ann.objects:
+        this_object_figures = [
+            figure for figure in ann.figures if figure.video_object.key() == video_object.key()
+        ]
+        this_object_changed = False
+        original_obj_class_name = video_object.obj_class.name
+        suffix_found = False
+        for suffix in ["", *obj_classes_suffixes]:
+            obj_class = video_object.obj_class
+            obj_class_name = obj_class.name + suffix
+            if suffix:
+                obj_class = obj_class.clone(name=obj_class_name)
+                video_object = video_object.clone(obj_class=obj_class)
+                any_object_updated = True
+                this_object_changed = True
+            meta_obj_class = meta.get_obj_class(obj_class_name)
+            if meta_obj_class is None:
+                # obj class is not in meta, add it with suffix
+                meta = meta.add_obj_class(obj_class)
+                new_objects.append(video_object)
+                meta_changed = True
+                suffix_found = True
+                break
+            elif (
+                meta_obj_class.geometry_type.geometry_name()
+                == video_object.obj_class.geometry_type.geometry_name()
+            ):
+                # if object geometry is the same as in meta, use meta obj class
+                video_object = video_object.clone(obj_class=meta_obj_class)
+                new_objects.append(video_object)
+                suffix_found = True
+                any_object_updated = True
+                this_object_changed = True
+                break
+            elif meta_obj_class.geometry_type.geometry_name() == AnyGeometry.geometry_name():
+                # if meta obj class is AnyGeometry, use it in object
+                video_object = video_object.clone(obj_class=meta_obj_class)
+                new_objects.append(video_object)
+                suffix_found = True
+                any_object_updated = True
+                this_object_changed = True
+                break
+        if not suffix_found:
+            # if no suffix found, raise error
+            raise ValueError(
+                f"Can't add obj class {original_obj_class_name} to project meta. "
+                "Tried with suffixes: " + ", ".join(obj_classes_suffixes) + ". "
+                "Please check if model geometry type is compatible with existing obj classes."
+            )
+        elif this_object_changed:
+            this_object_figures = [
+                figure.clone(video_object=video_object) for figure in this_object_figures
+            ]
+        new_figures.extend(this_object_figures)
+    if any_object_updated:
+        frames_figures = {}
+        for figure in new_figures:
+            frames_figures.setdefault(figure.frame_index, []).append(figure)
+        new_frames = FrameCollection(
+            [
+                Frame(index=frame_index, figures=figures)
+                for frame_index, figures in frames_figures.items()
+            ]
+        )
+        ann = ann.clone(objects=new_objects, frames=new_frames)

     # check if tag metas are in project meta
     # if not, add them with suffix
+    ann_tag_metas: Dict[str, TagMeta] = {}
+    for video_object in ann.objects:
+        for tag in video_object.tags:
+            tag_name = tag.meta.name
+            if tag_name not in ann_tag_metas:
+                ann_tag_metas[tag_name] = tag.meta
+    for tag in ann.tags:
+        tag_name = tag.meta.name
+        if tag_name not in ann_tag_metas:
+            ann_tag_metas[tag_name] = tag.meta
+
     changed_tag_metas = {}
-    for
-
-
+    for ann_tag_meta in ann_tag_metas.values():
+        meta_tag_meta = meta.get_tag_meta(ann_tag_meta.name)
+        if meta_tag_meta is None:
+            meta = meta.add_tag_meta(ann_tag_meta)
             meta_changed = True
-        elif not
-
+        elif not meta_tag_meta.is_compatible(ann_tag_meta):
+            suffix_found = False
             for suffix in tag_meta_suffixes:
-                new_tag_meta_name =
+                new_tag_meta_name = ann_tag_meta.name + suffix
                 meta_tag_meta = meta.get_tag_meta(new_tag_meta_name)
                 if meta_tag_meta is None:
-                    new_tag_meta =
+                    new_tag_meta = ann_tag_meta.clone(name=new_tag_meta_name)
                     meta = meta.add_tag_meta(new_tag_meta)
-                    changed_tag_metas[
+                    changed_tag_metas[ann_tag_meta.name] = new_tag_meta
                     meta_changed = True
-
+                    suffix_found = True
                     break
-                if meta_tag_meta.is_compatible(
-                    changed_tag_metas[
-
+                if meta_tag_meta.is_compatible(ann_tag_meta):
+                    changed_tag_metas[ann_tag_meta.name] = meta_tag_meta
+                    suffix_found = True
                     break
-            if not
-                raise ValueError(f"Can't add tag meta {
-
-
-
-
-
-
-
+            if not suffix_found:
+                raise ValueError(f"Can't add tag meta {ann_tag_meta.name} to project meta")
+
+    if changed_tag_metas:
+        objects = []
+        any_object_updated = False
+        for video_object in ann.objects:
+            any_tag_updated = False
+            object_tags = []
+            for tag in video_object.tags:
+                if tag.meta.name in changed_tag_metas:
+                    object_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+                    any_tag_updated = True
+                else:
+                    object_tags.append(tag)
+            if any_tag_updated:
+                video_object = video_object.clone(tags=TagCollection(object_tags))
+                any_object_updated = True
+            objects.append(video_object)
+
+        video_tags = []
+        any_tag_updated = False
+        for tag in ann.tags:
             if tag.meta.name in changed_tag_metas:
-
+                video_tags.append(tag.clone(meta=changed_tag_metas[tag.meta.name]))
+                any_tag_updated = True
             else:
-
-
-
-
-
-
-
-
+                video_tags.append(tag)
+        if any_tag_updated or any_object_updated:
+            if any_tag_updated:
+                video_tags = VideoTagCollection(video_tags)
+            else:
+                video_tags = None
+            if any_object_updated:
+                objects = VideoObjectCollection(objects)
+            else:
+                objects = None
+            ann = ann.clone(tags=video_tags, objects=objects)

-    ann = ann.clone(labels=labels, img_tags=TagCollection(img_tags))
     return meta, ann, meta_changed

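
The video counterpart returns the same `(meta, ann, meta_changed)` triple; when an obj class gets a suffix, the affected objects and their figures are re-linked via `clone`. Usage mirrors the image version (suffix value illustrative):

```python
meta, video_ann, meta_changed = update_meta_and_ann_for_video_annotation(
    meta, video_ann, model_prediction_suffix="_yolo"
)
```
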
@@ -4643,3 +5366,22 @@ def get_value_for_keys(data: dict, keys: List, ignore_none: bool = False):
             continue
         return data[key]
     return None
+
+
+def torch_load_safe(checkpoint_path: str, device: str = "cpu"):
+    import torch  # pylint: disable=import-error
+
+    # TODO: handle torch.load(weights_only=True) - change in torch 2.6.0
+    try:
+        logger.debug(f"Loading checkpoint from {checkpoint_path} on {device}")
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+        logger.debug(f"Checkpoint loaded from {checkpoint_path} on {device}")
+    except:
+        logger.debug(
+            f"Failed to load checkpoint from {checkpoint_path} on {device}. Trying again with weights_only=False"
+        )
+        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+        logger.debug(
+            f"Checkpoint loaded from {checkpoint_path} on {device} with weights_only=False"
+        )
+    return checkpoint
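
Usage sketch: `torch_load_safe` first tries a plain `torch.load` (which defaults to `weights_only=True` since torch 2.6) and retries with `weights_only=False` for checkpoints that pickle custom objects. The path and the "state_dict" key below are assumptions about the checkpoint layout:

```python
checkpoint = torch_load_safe("/models/best.pth", device="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)  # assumed layout
```
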