supervisely 6.73.438__py3-none-any.whl → 6.73.513__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/__init__.py +137 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/annotation.py +8 -2
- supervisely/annotation/json_geometries_map.py +14 -11
- supervisely/annotation/label.py +80 -3
- supervisely/api/annotation_api.py +14 -11
- supervisely/api/api.py +59 -38
- supervisely/api/app_api.py +11 -2
- supervisely/api/dataset_api.py +74 -12
- supervisely/api/entities_collection_api.py +10 -0
- supervisely/api/entity_annotation/figure_api.py +52 -4
- supervisely/api/entity_annotation/object_api.py +3 -3
- supervisely/api/entity_annotation/tag_api.py +63 -12
- supervisely/api/guides_api.py +210 -0
- supervisely/api/image_api.py +72 -1
- supervisely/api/labeling_job_api.py +83 -1
- supervisely/api/labeling_queue_api.py +33 -7
- supervisely/api/module_api.py +9 -0
- supervisely/api/project_api.py +71 -26
- supervisely/api/storage_api.py +3 -1
- supervisely/api/task_api.py +13 -2
- supervisely/api/team_api.py +4 -3
- supervisely/api/video/video_annotation_api.py +119 -3
- supervisely/api/video/video_api.py +65 -14
- supervisely/api/video/video_figure_api.py +24 -11
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +23 -7
- supervisely/app/development/development.py +18 -2
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +105 -0
- supervisely/app/fastapi/subapp.py +88 -42
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +6 -0
- supervisely/app/widgets/activity_feed/__init__.py +0 -0
- supervisely/app/widgets/activity_feed/activity_feed.py +239 -0
- supervisely/app/widgets/activity_feed/style.css +78 -0
- supervisely/app/widgets/activity_feed/template.html +22 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/classes_list_selector/classes_list_selector.py +121 -9
- supervisely/app/widgets/classes_list_selector/template.html +60 -93
- supervisely/app/widgets/classes_mapping/classes_mapping.py +13 -12
- supervisely/app/widgets/classes_table/classes_table.py +1 -0
- supervisely/app/widgets/deploy_model/deploy_model.py +56 -35
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +1 -1
- supervisely/app/widgets/experiment_selector/experiment_selector.py +8 -0
- supervisely/app/widgets/fast_table/fast_table.py +184 -60
- supervisely/app/widgets/fast_table/template.html +1 -1
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +564 -0
- supervisely/app/widgets/heatmap/script.js +533 -0
- supervisely/app/widgets/heatmap/style.css +233 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/modal/__init__.py +0 -0
- supervisely/app/widgets/modal/modal.py +198 -0
- supervisely/app/widgets/modal/template.html +10 -0
- supervisely/app/widgets/object_class_view/object_class_view.py +3 -0
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -3
- supervisely/app/widgets/select_class/__init__.py +0 -0
- supervisely/app/widgets/select_class/select_class.py +363 -0
- supervisely/app/widgets/select_class/template.html +50 -0
- supervisely/app/widgets/select_cuda/select_cuda.py +22 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +65 -7
- supervisely/app/widgets/select_tag/__init__.py +0 -0
- supervisely/app/widgets/select_tag/select_tag.py +352 -0
- supervisely/app/widgets/select_tag/template.html +64 -0
- supervisely/app/widgets/select_team/select_team.py +37 -4
- supervisely/app/widgets/select_team/template.html +4 -5
- supervisely/app/widgets/select_user/__init__.py +0 -0
- supervisely/app/widgets/select_user/select_user.py +270 -0
- supervisely/app/widgets/select_user/template.html +13 -0
- supervisely/app/widgets/select_workspace/select_workspace.py +59 -10
- supervisely/app/widgets/select_workspace/template.html +9 -12
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/aug/aug.py +6 -2
- supervisely/convert/base_converter.py +1 -0
- supervisely/convert/converter.py +2 -2
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/image/image_converter.py +3 -1
- supervisely/convert/image/image_helper.py +48 -4
- supervisely/convert/image/label_studio/label_studio_converter.py +2 -0
- supervisely/convert/image/medical2d/medical2d_helper.py +2 -24
- supervisely/convert/image/multispectral/multispectral_converter.py +6 -0
- supervisely/convert/image/pascal_voc/pascal_voc_converter.py +8 -5
- supervisely/convert/image/pascal_voc/pascal_voc_helper.py +7 -0
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_converter.py +33 -3
- supervisely/convert/pointcloud/kitti_3d/kitti_3d_helper.py +12 -5
- supervisely/convert/pointcloud/las/las_converter.py +13 -1
- supervisely/convert/pointcloud/las/las_helper.py +110 -11
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +27 -16
- supervisely/convert/pointcloud/pointcloud_converter.py +91 -3
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +58 -22
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +21 -47
- supervisely/convert/video/__init__.py +1 -0
- supervisely/convert/video/multi_view/__init__.py +0 -0
- supervisely/convert/video/multi_view/multi_view.py +543 -0
- supervisely/convert/video/sly/sly_video_converter.py +359 -3
- supervisely/convert/video/video_converter.py +24 -4
- supervisely/convert/volume/dicom/dicom_converter.py +13 -5
- supervisely/convert/volume/dicom/dicom_helper.py +30 -18
- supervisely/geometry/constants.py +1 -0
- supervisely/geometry/geometry.py +4 -0
- supervisely/geometry/helpers.py +5 -1
- supervisely/geometry/oriented_bbox.py +676 -0
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/geometry/rectangle.py +2 -1
- supervisely/io/env.py +76 -1
- supervisely/io/fs.py +21 -0
- supervisely/nn/benchmark/base_evaluator.py +104 -11
- supervisely/nn/benchmark/instance_segmentation/evaluator.py +1 -8
- supervisely/nn/benchmark/object_detection/evaluator.py +20 -4
- supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +10 -5
- supervisely/nn/benchmark/semantic_segmentation/evaluator.py +34 -16
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +1 -1
- supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +1 -1
- supervisely/nn/benchmark/visualization/evaluation_result.py +66 -4
- supervisely/nn/inference/cache.py +43 -18
- supervisely/nn/inference/gui/serving_gui_template.py +5 -2
- supervisely/nn/inference/inference.py +916 -222
- supervisely/nn/inference/inference_request.py +55 -10
- supervisely/nn/inference/predict_app/gui/classes_selector.py +83 -12
- supervisely/nn/inference/predict_app/gui/gui.py +676 -488
- supervisely/nn/inference/predict_app/gui/input_selector.py +205 -26
- supervisely/nn/inference/predict_app/gui/model_selector.py +2 -4
- supervisely/nn/inference/predict_app/gui/output_selector.py +46 -6
- supervisely/nn/inference/predict_app/gui/settings_selector.py +756 -59
- supervisely/nn/inference/predict_app/gui/tags_selector.py +1 -1
- supervisely/nn/inference/predict_app/gui/utils.py +236 -119
- supervisely/nn/inference/predict_app/predict_app.py +2 -2
- supervisely/nn/inference/session.py +43 -35
- supervisely/nn/inference/tracking/bbox_tracking.py +118 -35
- supervisely/nn/inference/tracking/point_tracking.py +5 -1
- supervisely/nn/inference/tracking/tracker_interface.py +10 -1
- supervisely/nn/inference/uploader.py +139 -12
- supervisely/nn/live_training/__init__.py +7 -0
- supervisely/nn/live_training/api_server.py +111 -0
- supervisely/nn/live_training/artifacts_utils.py +243 -0
- supervisely/nn/live_training/checkpoint_utils.py +229 -0
- supervisely/nn/live_training/dynamic_sampler.py +44 -0
- supervisely/nn/live_training/helpers.py +14 -0
- supervisely/nn/live_training/incremental_dataset.py +146 -0
- supervisely/nn/live_training/live_training.py +497 -0
- supervisely/nn/live_training/loss_plateau_detector.py +111 -0
- supervisely/nn/live_training/request_queue.py +52 -0
- supervisely/nn/model/model_api.py +9 -0
- supervisely/nn/model/prediction.py +2 -1
- supervisely/nn/model/prediction_session.py +26 -14
- supervisely/nn/prediction_dto.py +19 -1
- supervisely/nn/tracker/base_tracker.py +11 -1
- supervisely/nn/tracker/botsort/botsort_config.yaml +0 -1
- supervisely/nn/tracker/botsort/tracker/mc_bot_sort.py +7 -4
- supervisely/nn/tracker/botsort_tracker.py +94 -65
- supervisely/nn/tracker/utils.py +4 -5
- supervisely/nn/tracker/visualize.py +93 -93
- supervisely/nn/training/gui/classes_selector.py +16 -1
- supervisely/nn/training/gui/train_val_splits_selector.py +52 -31
- supervisely/nn/training/train_app.py +46 -31
- supervisely/project/data_version.py +115 -51
- supervisely/project/download.py +1 -1
- supervisely/project/pointcloud_episode_project.py +37 -8
- supervisely/project/pointcloud_project.py +30 -2
- supervisely/project/project.py +14 -2
- supervisely/project/project_meta.py +27 -1
- supervisely/project/project_settings.py +32 -18
- supervisely/project/versioning/__init__.py +1 -0
- supervisely/project/versioning/common.py +20 -0
- supervisely/project/versioning/schema_fields.py +35 -0
- supervisely/project/versioning/video_schema.py +221 -0
- supervisely/project/versioning/volume_schema.py +87 -0
- supervisely/project/video_project.py +717 -15
- supervisely/project/volume_project.py +623 -5
- supervisely/template/experiment/experiment.html.jinja +4 -4
- supervisely/template/experiment/experiment_generator.py +14 -21
- supervisely/template/live_training/__init__.py +0 -0
- supervisely/template/live_training/header.html.jinja +96 -0
- supervisely/template/live_training/live_training.html.jinja +51 -0
- supervisely/template/live_training/live_training_generator.py +464 -0
- supervisely/template/live_training/sly-style.css +402 -0
- supervisely/template/live_training/template.html.jinja +18 -0
- supervisely/versions.json +28 -26
- supervisely/video/sampling.py +39 -20
- supervisely/video/video.py +41 -12
- supervisely/video_annotation/video_figure.py +38 -4
- supervisely/video_annotation/video_object.py +29 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/METADATA +58 -40
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/RECORD +203 -155
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/WHEEL +1 -1
- supervisely_lib/__init__.py +6 -1
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info/licenses}/LICENSE +0 -0
- {supervisely-6.73.438.dist-info → supervisely-6.73.513.dist-info}/top_level.txt +0 -0
|
@@ -71,7 +71,7 @@ class PredictionSession:
|
|
|
71
71
|
tracking_config: dict = None,
|
|
72
72
|
**kwargs: dict,
|
|
73
73
|
):
|
|
74
|
-
|
|
74
|
+
|
|
75
75
|
extra_input_args = ["image_ids", "video_ids", "dataset_ids", "project_ids"]
|
|
76
76
|
assert (
|
|
77
77
|
sum(
|
|
@@ -90,7 +90,6 @@ class PredictionSession:
|
|
|
90
90
|
== 1
|
|
91
91
|
), "Exactly one of input, image_ids, video_id, dataset_id, project_id or image_id must be provided."
|
|
92
92
|
|
|
93
|
-
|
|
94
93
|
self._iterator = None
|
|
95
94
|
self._base_url = url
|
|
96
95
|
self.inference_request_uuid = None
|
|
@@ -115,12 +114,12 @@ class PredictionSession:
|
|
|
115
114
|
self.inference_settings = {
|
|
116
115
|
k: v for k, v in kwargs.items() if isinstance(v, (str, int, float))
|
|
117
116
|
}
|
|
118
|
-
|
|
117
|
+
|
|
119
118
|
if tracking is True:
|
|
120
119
|
model_info = self._get_session_info()
|
|
121
120
|
if not model_info.get("tracking_on_videos_support", False):
|
|
122
121
|
raise ValueError("Tracking is not supported by this model")
|
|
123
|
-
|
|
122
|
+
|
|
124
123
|
if tracking_config is None:
|
|
125
124
|
self.tracker = "botsort"
|
|
126
125
|
self.tracker_settings = {}
|
|
@@ -286,7 +285,7 @@ class PredictionSession:
|
|
|
286
285
|
if self.api is not None:
|
|
287
286
|
return self.api.token
|
|
288
287
|
return env.api_token(raise_not_found=False)
|
|
289
|
-
|
|
288
|
+
|
|
290
289
|
def _get_json_body(self):
|
|
291
290
|
body = {"state": {}, "context": {}}
|
|
292
291
|
if self.inference_request_uuid is not None:
|
|
@@ -298,7 +297,7 @@ class PredictionSession:
|
|
|
298
297
|
if "model_prediction_suffix" in self.kwargs:
|
|
299
298
|
body["state"]["model_prediction_suffix"] = self.kwargs["model_prediction_suffix"]
|
|
300
299
|
return body
|
|
301
|
-
|
|
300
|
+
|
|
302
301
|
def _post(self, method, *args, retries=5, **kwargs) -> requests.Response:
|
|
303
302
|
if kwargs.get("headers") is None:
|
|
304
303
|
kwargs["headers"] = {}
|
|
@@ -336,7 +335,7 @@ class PredictionSession:
|
|
|
336
335
|
method = "get_session_info"
|
|
337
336
|
r = self._post(method, json=self._get_json_body())
|
|
338
337
|
return r.json()
|
|
339
|
-
|
|
338
|
+
|
|
340
339
|
def _get_inference_progress(self):
|
|
341
340
|
method = "get_inference_progress"
|
|
342
341
|
r = self._post(method, json=self._get_json_body())
|
|
@@ -365,9 +364,21 @@ class PredictionSession:
|
|
|
365
364
|
logger.info("Inference request will be cleared on the server")
|
|
366
365
|
return r.json()
|
|
367
366
|
|
|
367
|
+
def _get_final_result(self):
|
|
368
|
+
method = "get_inference_result"
|
|
369
|
+
r = self._post(
|
|
370
|
+
method,
|
|
371
|
+
json=self._get_json_body(),
|
|
372
|
+
)
|
|
373
|
+
return r.json()
|
|
374
|
+
|
|
368
375
|
def _on_infernce_end(self):
|
|
369
376
|
if self.inference_request_uuid is None:
|
|
370
377
|
return
|
|
378
|
+
try:
|
|
379
|
+
self.final_result = self._get_final_result()
|
|
380
|
+
except Exception as e:
|
|
381
|
+
logger.debug("Failed to get final result:", exc_info=True)
|
|
371
382
|
self._clear_inference_request()
|
|
372
383
|
|
|
373
384
|
@property
|
|
@@ -512,18 +523,16 @@ class PredictionSession:
|
|
|
512
523
|
"Inference is already running. Please stop it before starting a new one."
|
|
513
524
|
)
|
|
514
525
|
resp = self._post(method, **kwargs).json()
|
|
515
|
-
|
|
516
526
|
self.inference_request_uuid = resp["inference_request_uuid"]
|
|
517
|
-
|
|
518
|
-
logger.info(
|
|
519
|
-
"Inference has started:",
|
|
520
|
-
extra={"inference_request_uuid": resp.get("inference_request_uuid")},
|
|
521
|
-
)
|
|
522
527
|
try:
|
|
523
528
|
resp, has_started = self._wait_for_inference_start(tqdm=self.tqdm)
|
|
524
529
|
except:
|
|
525
530
|
self.stop()
|
|
526
531
|
raise
|
|
532
|
+
logger.info(
|
|
533
|
+
"Inference has started:",
|
|
534
|
+
extra={"inference_request_uuid": resp.get("inference_request_uuid")},
|
|
535
|
+
)
|
|
527
536
|
frame_iterator = self.Iterator(resp["progress"]["total"], self, tqdm=self.tqdm)
|
|
528
537
|
return frame_iterator
|
|
529
538
|
|
|
@@ -636,8 +645,11 @@ class PredictionSession:
|
|
|
636
645
|
encoder = MultipartEncoder(fields)
|
|
637
646
|
if self.tqdm is not None:
|
|
638
647
|
|
|
648
|
+
bytes_read = 0
|
|
639
649
|
def _callback(monitor):
|
|
640
|
-
|
|
650
|
+
nonlocal bytes_read
|
|
651
|
+
self.tqdm.update(monitor.bytes_read - bytes_read)
|
|
652
|
+
bytes_read = monitor.bytes_read
|
|
641
653
|
|
|
642
654
|
video_size = get_file_size(video_path)
|
|
643
655
|
self._update_progress(self.tqdm, "Uploading video", 0, video_size, is_size=True)
|
supervisely/nn/prediction_dto.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import List, Optional
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
|
|
5
5
|
from supervisely.geometry.cuboid_3d import Cuboid3d
|
|
6
|
+
from supervisely.geometry.polyline_3d import Polyline3D
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class Prediction:
|
|
@@ -24,10 +25,21 @@ class PredictionMask(Prediction):
|
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class PredictionBBox(Prediction):
|
|
27
|
-
def __init__(self, class_name: str, bbox_tlbr: List[int], score: Optional[float]):
|
|
28
|
+
def __init__(self, class_name: str, bbox_tlbr: List[int], score: Optional[float], angle: Optional[float] = None):
|
|
29
|
+
"""
|
|
30
|
+
:param class_name: Predicted class name.
|
|
31
|
+
:type class_name: str
|
|
32
|
+
:param bbox_tlbr: Bounding box in (top, left, bottom, right) format.
|
|
33
|
+
:type bbox_tlbr: list of 4 ints
|
|
34
|
+
:param score: Confidence score.
|
|
35
|
+
:type score: float, optional
|
|
36
|
+
:param angle: Angle of rotation in radians. Positive values mean clockwise rotation.
|
|
37
|
+
:type angle: int or float, optional
|
|
38
|
+
"""
|
|
28
39
|
super(PredictionBBox, self).__init__(class_name=class_name)
|
|
29
40
|
self.bbox_tlbr = bbox_tlbr
|
|
30
41
|
self.score = score
|
|
42
|
+
self.angle = angle
|
|
31
43
|
|
|
32
44
|
|
|
33
45
|
class PredictionSegmentation(Prediction):
|
|
@@ -81,3 +93,9 @@ class PredictionCuboid3d(Prediction):
|
|
|
81
93
|
super(PredictionCuboid3d, self).__init__(class_name=class_name)
|
|
82
94
|
self.cuboid_3d = cuboid_3d
|
|
83
95
|
self.score = score
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class PredictionPolyline3D(Prediction):
|
|
99
|
+
def __init__(self, class_name: str, polyline_3d: Polyline3D):
|
|
100
|
+
super(PredictionPolyline3D, self).__init__(class_name=class_name)
|
|
101
|
+
self.polyline_3d = polyline_3d
|
|
@@ -36,9 +36,19 @@ class BaseTracker:
|
|
|
36
36
|
def video_annotation(self) -> VideoAnnotation:
|
|
37
37
|
"""Return the accumulated VideoAnnotation."""
|
|
38
38
|
raise NotImplementedError("This method should be overridden by subclasses.")
|
|
39
|
+
|
|
40
|
+
@classmethod
|
|
41
|
+
def get_default_params(cls) -> Dict[str, Any]:
|
|
42
|
+
"""
|
|
43
|
+
Get default configurable parameters for this tracker.
|
|
44
|
+
Must be implemented in subclass.
|
|
45
|
+
"""
|
|
46
|
+
raise NotImplementedError(
|
|
47
|
+
f"Method get_default_params() must be implemented in {cls.__name__}"
|
|
48
|
+
)
|
|
39
49
|
|
|
40
50
|
def _validate_device(self) -> None:
|
|
41
51
|
if self.device != 'cpu' and not self.device.startswith('cuda'):
|
|
42
52
|
raise ValueError(
|
|
43
53
|
f"Invalid device '{self.device}'. Supported devices are 'cpu' or 'cuda'."
|
|
44
|
-
)
|
|
54
|
+
)
|
|
@@ -1,13 +1,16 @@
|
|
|
1
|
-
import numpy as np
|
|
2
1
|
from collections import deque
|
|
3
2
|
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from supervisely.nn.tracker.botsort.osnet_reid.osnet_reid_interface import (
|
|
6
|
+
OsnetReIDInterface,
|
|
7
|
+
)
|
|
8
|
+
|
|
4
9
|
from . import matching
|
|
5
|
-
from .gmc import GMC
|
|
6
10
|
from .basetrack import BaseTrack, TrackState
|
|
11
|
+
from .gmc import GMC
|
|
7
12
|
from .kalman_filter import KalmanFilter
|
|
8
13
|
|
|
9
|
-
from supervisely.nn.tracker.botsort.osnet_reid.osnet_reid_interface import OsnetReIDInterface
|
|
10
|
-
|
|
11
14
|
|
|
12
15
|
class STrack(BaseTrack):
|
|
13
16
|
|
|
@@ -1,15 +1,16 @@
|
|
|
1
|
-
import
|
|
2
|
-
from supervisely.nn.tracker.base_tracker import BaseTracker
|
|
3
|
-
from supervisely import Annotation, VideoAnnotation
|
|
4
|
-
|
|
1
|
+
import os
|
|
5
2
|
from dataclasses import dataclass
|
|
3
|
+
from pathlib import Path
|
|
6
4
|
from types import SimpleNamespace
|
|
7
|
-
from typing import
|
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
6
|
+
|
|
8
7
|
import numpy as np
|
|
9
8
|
import yaml
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
from supervisely import logger
|
|
9
|
+
|
|
10
|
+
import supervisely as sly
|
|
11
|
+
from supervisely import Annotation, VideoAnnotation, logger
|
|
12
|
+
from supervisely.annotation.label import Label, LabelingStatus
|
|
13
|
+
from supervisely.nn.tracker.base_tracker import BaseTracker
|
|
13
14
|
from supervisely.nn.tracker.botsort.tracker.mc_bot_sort import BoTSORT
|
|
14
15
|
|
|
15
16
|
|
|
@@ -32,10 +33,11 @@ class TrackedObject:
|
|
|
32
33
|
class_name: str
|
|
33
34
|
class_sly_id: Optional[int] # Supervisely class ID
|
|
34
35
|
score: float
|
|
36
|
+
original_label: Optional[sly.Label] = None
|
|
35
37
|
|
|
36
38
|
|
|
37
39
|
class BotSortTracker(BaseTracker):
|
|
38
|
-
|
|
40
|
+
|
|
39
41
|
def __init__(self, settings: dict = None, device: str = None):
|
|
40
42
|
super().__init__(settings=settings, device=device)
|
|
41
43
|
|
|
@@ -45,33 +47,30 @@ class BotSortTracker(BaseTracker):
|
|
|
45
47
|
"Tracking dependencies are not installed. "
|
|
46
48
|
"Please install supervisely with `pip install supervisely[tracking]`."
|
|
47
49
|
)
|
|
48
|
-
|
|
50
|
+
|
|
49
51
|
# Load default settings from YAML file
|
|
50
52
|
self.settings = self._load_default_settings()
|
|
51
|
-
|
|
53
|
+
|
|
52
54
|
# Override with user settings if provided
|
|
53
55
|
if settings:
|
|
54
56
|
self.settings.update(settings)
|
|
55
|
-
|
|
57
|
+
|
|
56
58
|
args = SimpleNamespace(**self.settings)
|
|
59
|
+
args.name = "BotSORT"
|
|
57
60
|
args.device = self.device
|
|
58
|
-
|
|
61
|
+
|
|
59
62
|
self.tracker = BoTSORT(args=args)
|
|
60
|
-
|
|
63
|
+
|
|
61
64
|
# State for accumulating results
|
|
62
|
-
self.frame_tracks = []
|
|
65
|
+
self.frame_tracks: List[List[TrackedObject]] = []
|
|
63
66
|
self.obj_classes = {} # class_id -> ObjClass
|
|
64
67
|
self.current_frame = 0
|
|
65
68
|
self.class_ids = {} # class_name -> class_id mapping
|
|
66
69
|
self.frame_shape = ()
|
|
67
70
|
|
|
68
71
|
def _load_default_settings(self) -> dict:
|
|
69
|
-
"""
|
|
70
|
-
|
|
71
|
-
config_path = current_dir / "botsort/botsort_config.yaml"
|
|
72
|
-
|
|
73
|
-
with open(config_path, 'r', encoding='utf-8') as file:
|
|
74
|
-
return yaml.safe_load(file)
|
|
72
|
+
"""Internal method: calls classmethod"""
|
|
73
|
+
return self.get_default_params()
|
|
75
74
|
|
|
76
75
|
def update(self, frame: np.ndarray, annotation: Annotation) -> List[Dict[str, Any]]:
|
|
77
76
|
"""Update tracker and return list of matches for current frame."""
|
|
@@ -79,12 +78,12 @@ class BotSortTracker(BaseTracker):
|
|
|
79
78
|
self._update_obj_classes(annotation)
|
|
80
79
|
detections = self._convert_annotation(annotation)
|
|
81
80
|
output_stracks, detection_track_map = self.tracker.update(detections, frame)
|
|
82
|
-
tracks = self._stracks_to_tracks(output_stracks, detection_track_map)
|
|
83
|
-
|
|
81
|
+
tracks = self._stracks_to_tracks(output_stracks, detection_track_map, annotation.labels)
|
|
82
|
+
|
|
84
83
|
# Store tracks for VideoAnnotation creation
|
|
85
84
|
self.frame_tracks.append(tracks)
|
|
86
85
|
self.current_frame += 1
|
|
87
|
-
|
|
86
|
+
|
|
88
87
|
matches = []
|
|
89
88
|
for pair in detection_track_map:
|
|
90
89
|
det_id = pair["det_id"]
|
|
@@ -96,9 +95,9 @@ class BotSortTracker(BaseTracker):
|
|
|
96
95
|
"label": annotation.labels[det_id]
|
|
97
96
|
}
|
|
98
97
|
matches.append(match)
|
|
99
|
-
|
|
98
|
+
|
|
100
99
|
return matches
|
|
101
|
-
|
|
100
|
+
|
|
102
101
|
def reset(self) -> None:
|
|
103
102
|
super().reset()
|
|
104
103
|
self.frame_tracks = []
|
|
@@ -109,19 +108,16 @@ class BotSortTracker(BaseTracker):
|
|
|
109
108
|
|
|
110
109
|
def track(self, frames: List[np.ndarray], annotations: List[Annotation]) -> VideoAnnotation:
|
|
111
110
|
"""Track objects through sequence of frames and return VideoAnnotation."""
|
|
112
|
-
if len(frames) != len(annotations):
|
|
113
|
-
raise ValueError("Number of frames and annotations must match")
|
|
114
|
-
|
|
115
111
|
self.reset()
|
|
116
|
-
|
|
112
|
+
|
|
117
113
|
# Process each frame
|
|
118
114
|
for frame_idx, (frame, annotation) in enumerate(zip(frames, annotations)):
|
|
119
115
|
self.current_frame = frame_idx
|
|
120
116
|
self.update(frame, annotation)
|
|
121
|
-
|
|
117
|
+
|
|
122
118
|
# Convert accumulated tracks to VideoAnnotation
|
|
123
|
-
return self.
|
|
124
|
-
|
|
119
|
+
return self.create_video_annotation()
|
|
120
|
+
|
|
125
121
|
def _convert_annotation(self, annotation: Annotation) -> np.ndarray:
|
|
126
122
|
"""Convert Supervisely annotation to BoTSORT detection format."""
|
|
127
123
|
detections_list = []
|
|
@@ -138,10 +134,10 @@ class BotSortTracker(BaseTracker):
|
|
|
138
134
|
)
|
|
139
135
|
|
|
140
136
|
rectangle = label.geometry.to_bbox()
|
|
141
|
-
|
|
137
|
+
|
|
142
138
|
class_name = label.obj_class.name
|
|
143
139
|
class_id = self.class_ids[class_name]
|
|
144
|
-
|
|
140
|
+
|
|
145
141
|
detection = [
|
|
146
142
|
rectangle.left, # x1
|
|
147
143
|
rectangle.top, # y1
|
|
@@ -156,103 +152,127 @@ class BotSortTracker(BaseTracker):
|
|
|
156
152
|
return np.array(detections_list, dtype=np.float32)
|
|
157
153
|
else:
|
|
158
154
|
return np.zeros((0, 6), dtype=np.float32)
|
|
159
|
-
|
|
160
|
-
def _stracks_to_tracks(
|
|
155
|
+
|
|
156
|
+
def _stracks_to_tracks(
|
|
157
|
+
self, output_stracks, detection_track_map, labels: List[Label]
|
|
158
|
+
) -> List[TrackedObject]:
|
|
161
159
|
"""Convert BoTSORT output tracks to TrackedObject dataclass instances."""
|
|
162
160
|
tracks = []
|
|
163
|
-
|
|
161
|
+
|
|
164
162
|
id_to_name = {v: k for k, v in self.class_ids.items()}
|
|
165
|
-
|
|
163
|
+
|
|
166
164
|
track_id_to_det_id = {}
|
|
167
165
|
for pair in detection_track_map:
|
|
168
166
|
det_id = pair["det_id"]
|
|
169
167
|
track_id = pair["track_id"]
|
|
170
168
|
track_id_to_det_id[track_id] = det_id
|
|
171
|
-
|
|
169
|
+
|
|
172
170
|
for strack in output_stracks:
|
|
171
|
+
det_id = track_id_to_det_id.get(strack.track_id)
|
|
172
|
+
if det_id is None:
|
|
173
|
+
continue # Skip tracks without associated detection
|
|
174
|
+
|
|
173
175
|
# BoTSORT may store class info in different attributes
|
|
174
176
|
# Try to get class_id from various possible sources
|
|
175
177
|
class_id = 0 # default
|
|
176
|
-
|
|
178
|
+
|
|
177
179
|
if hasattr(strack, 'cls') and strack.cls != -1:
|
|
178
180
|
# cls should contain the numeric ID we passed in
|
|
179
181
|
class_id = int(strack.cls)
|
|
180
182
|
elif hasattr(strack, 'class_id'):
|
|
181
183
|
class_id = int(strack.class_id)
|
|
182
|
-
|
|
184
|
+
|
|
183
185
|
class_name = id_to_name.get(class_id, "unknown")
|
|
184
|
-
|
|
186
|
+
|
|
185
187
|
# Get Supervisely class ID from stored ObjClass
|
|
186
188
|
class_sly_id = None
|
|
187
189
|
if class_name in self.obj_classes:
|
|
188
190
|
obj_class = self.obj_classes[class_name]
|
|
189
191
|
class_sly_id = obj_class.sly_id
|
|
190
|
-
|
|
192
|
+
|
|
193
|
+
label = labels[det_id]
|
|
191
194
|
track = TrackedObject(
|
|
192
195
|
track_id=strack.track_id,
|
|
193
196
|
det_id=track_id_to_det_id.get(strack.track_id),
|
|
194
197
|
bbox=strack.tlbr.tolist(), # [x1, y1, x2, y2]
|
|
195
198
|
class_name=class_name,
|
|
196
199
|
class_sly_id=class_sly_id,
|
|
197
|
-
score=getattr(strack,
|
|
200
|
+
score=getattr(strack, "score", 1.0),
|
|
201
|
+
original_label=label,
|
|
198
202
|
)
|
|
199
203
|
tracks.append(track)
|
|
200
|
-
|
|
204
|
+
|
|
201
205
|
return tracks
|
|
202
|
-
|
|
206
|
+
|
|
203
207
|
def _update_obj_classes(self, annotation: Annotation):
|
|
204
208
|
"""Extract and store object classes from annotation."""
|
|
205
209
|
for label in annotation.labels:
|
|
206
210
|
class_name = label.obj_class.name
|
|
207
211
|
if class_name not in self.obj_classes:
|
|
208
212
|
self.obj_classes[class_name] = label.obj_class
|
|
209
|
-
|
|
213
|
+
|
|
210
214
|
if class_name not in self.class_ids:
|
|
211
215
|
self.class_ids[class_name] = len(self.class_ids)
|
|
212
216
|
|
|
213
|
-
|
|
214
|
-
|
|
217
|
+
def create_video_annotation(
|
|
218
|
+
self,
|
|
219
|
+
video_frames_count: Optional[int] = None,
|
|
220
|
+
frame_index: Optional[int] = 0,
|
|
221
|
+
step: Optional[int] = 1,
|
|
222
|
+
progress_cb: Optional[Callable[[int], None]] = None,
|
|
223
|
+
) -> VideoAnnotation:
|
|
215
224
|
"""Convert accumulated tracking results to Supervisely VideoAnnotation."""
|
|
216
225
|
img_h, img_w = self.frame_shape
|
|
217
226
|
video_objects = {} # track_id -> VideoObject
|
|
218
227
|
frames = []
|
|
219
|
-
|
|
220
|
-
|
|
228
|
+
if video_frames_count is None:
|
|
229
|
+
video_frames_count = len(self.frame_tracks)
|
|
230
|
+
|
|
231
|
+
for i, tracks in enumerate(self.frame_tracks, frame_index):
|
|
232
|
+
frame_idx = frame_index + i * step
|
|
221
233
|
frame_figures = []
|
|
222
|
-
|
|
234
|
+
|
|
223
235
|
for track in tracks:
|
|
224
236
|
track_id = track.track_id
|
|
225
237
|
bbox = track.bbox # [x1, y1, x2, y2]
|
|
226
238
|
class_name = track.class_name
|
|
227
|
-
|
|
239
|
+
|
|
228
240
|
# Clip bbox to image boundaries
|
|
229
241
|
x1, y1, x2, y2 = bbox
|
|
230
242
|
dims = np.array([img_w, img_h, img_w, img_h]) - 1
|
|
231
243
|
x1, y1, x2, y2 = np.clip([x1, y1, x2, y2], 0, dims)
|
|
232
|
-
|
|
244
|
+
|
|
233
245
|
# Get or create VideoObject
|
|
234
246
|
if track_id not in video_objects:
|
|
235
247
|
obj_class = self.obj_classes.get(class_name)
|
|
236
248
|
if obj_class is None:
|
|
237
249
|
continue # Skip if class not found
|
|
238
250
|
video_objects[track_id] = sly.VideoObject(obj_class)
|
|
239
|
-
|
|
251
|
+
|
|
240
252
|
video_object = video_objects[track_id]
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
253
|
+
frame_figures.append(
|
|
254
|
+
sly.VideoFigure(
|
|
255
|
+
video_object,
|
|
256
|
+
track.original_label.geometry,
|
|
257
|
+
frame_idx,
|
|
258
|
+
track_id=str(track_id),
|
|
259
|
+
status=LabelingStatus.AUTO,
|
|
260
|
+
)
|
|
261
|
+
)
|
|
262
|
+
|
|
244
263
|
frames.append(sly.Frame(frame_idx, frame_figures))
|
|
264
|
+
if progress_cb is not None:
|
|
265
|
+
progress_cb()
|
|
245
266
|
|
|
246
267
|
objects = list(video_objects.values())
|
|
247
268
|
|
|
248
|
-
|
|
249
269
|
return VideoAnnotation(
|
|
250
270
|
img_size=self.frame_shape,
|
|
251
|
-
frames_count=
|
|
271
|
+
frames_count=video_frames_count,
|
|
252
272
|
objects=sly.VideoObjectCollection(objects),
|
|
253
|
-
frames=sly.FrameCollection(frames)
|
|
273
|
+
frames=sly.FrameCollection(frames),
|
|
254
274
|
)
|
|
255
|
-
|
|
275
|
+
|
|
256
276
|
@property
|
|
257
277
|
def video_annotation(self) -> VideoAnnotation:
|
|
258
278
|
"""Return the accumulated VideoAnnotation."""
|
|
@@ -262,5 +282,14 @@ class BotSortTracker(BaseTracker):
|
|
|
262
282
|
"Please run tracking first using track() method or process frames with update()."
|
|
263
283
|
)
|
|
264
284
|
raise ValueError(error_msg)
|
|
265
|
-
|
|
266
|
-
return self.
|
|
285
|
+
|
|
286
|
+
return self.create_video_annotation()
|
|
287
|
+
|
|
288
|
+
@classmethod
|
|
289
|
+
def get_default_params(cls) -> Dict[str, Any]:
|
|
290
|
+
"""Public API: get default params WITHOUT creating instance."""
|
|
291
|
+
current_dir = Path(__file__).parent
|
|
292
|
+
config_path = current_dir / "botsort/botsort_config.yaml"
|
|
293
|
+
|
|
294
|
+
with open(config_path, 'r', encoding='utf-8') as file:
|
|
295
|
+
return yaml.safe_load(file)
|
supervisely/nn/tracker/utils.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
|
|
2
1
|
from typing import List, Union, Dict, Tuple
|
|
3
2
|
from pathlib import Path
|
|
4
3
|
from collections import defaultdict
|
|
@@ -6,6 +5,7 @@ import numpy as np
|
|
|
6
5
|
|
|
7
6
|
import supervisely as sly
|
|
8
7
|
from supervisely.nn.model.prediction import Prediction
|
|
8
|
+
from supervisely.annotation.label import LabelingStatus
|
|
9
9
|
from supervisely import VideoAnnotation
|
|
10
10
|
from supervisely import logger
|
|
11
11
|
|
|
@@ -73,12 +73,11 @@ def predictions_to_video_annotation(
|
|
|
73
73
|
|
|
74
74
|
video_object = video_objects[track_id]
|
|
75
75
|
rect = sly.Rectangle(top=top, left=left, bottom=bottom, right=right)
|
|
76
|
-
frame_figures.append(sly.VideoFigure(video_object, rect, frame_idx, track_id=str(track_id)))
|
|
76
|
+
frame_figures.append(sly.VideoFigure(video_object, rect, frame_idx, track_id=str(track_id), status=LabelingStatus.AUTO))
|
|
77
77
|
|
|
78
78
|
frames.append(sly.Frame(frame_idx, frame_figures))
|
|
79
79
|
|
|
80
|
-
objects = list(video_objects.values())
|
|
81
|
-
|
|
80
|
+
objects = list(video_objects.values())
|
|
82
81
|
return VideoAnnotation(
|
|
83
82
|
img_size=frame_shape,
|
|
84
83
|
frames_count=len(predictions),
|
|
@@ -271,4 +270,4 @@ def mot_to_video_annotation(
|
|
|
271
270
|
|
|
272
271
|
logger.info(f"Created VideoAnnotation with {len(objects)} tracks and {frames_count} frames")
|
|
273
272
|
|
|
274
|
-
return annotation
|
|
273
|
+
return annotation
|