supervisely-6.73.410-py3-none-any.whl → supervisely-6.73.470-py3-none-any.whl
- supervisely/__init__.py +136 -1
- supervisely/_utils.py +81 -0
- supervisely/annotation/json_geometries_map.py +2 -0
- supervisely/annotation/label.py +80 -3
- supervisely/api/annotation_api.py +9 -9
- supervisely/api/api.py +67 -43
- supervisely/api/app_api.py +72 -5
- supervisely/api/dataset_api.py +108 -33
- supervisely/api/entity_annotation/figure_api.py +113 -49
- supervisely/api/image_api.py +82 -0
- supervisely/api/module_api.py +10 -0
- supervisely/api/nn/deploy_api.py +15 -9
- supervisely/api/nn/ecosystem_models_api.py +201 -0
- supervisely/api/nn/neural_network_api.py +12 -3
- supervisely/api/pointcloud/pointcloud_api.py +38 -0
- supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
- supervisely/api/project_api.py +213 -6
- supervisely/api/task_api.py +11 -1
- supervisely/api/video/video_annotation_api.py +4 -2
- supervisely/api/video/video_api.py +79 -1
- supervisely/api/video/video_figure_api.py +24 -11
- supervisely/api/volume/volume_api.py +38 -0
- supervisely/app/__init__.py +1 -1
- supervisely/app/content.py +14 -6
- supervisely/app/fastapi/__init__.py +1 -0
- supervisely/app/fastapi/custom_static_files.py +1 -1
- supervisely/app/fastapi/multi_user.py +88 -0
- supervisely/app/fastapi/subapp.py +175 -42
- supervisely/app/fastapi/templating.py +1 -1
- supervisely/app/fastapi/websocket.py +77 -9
- supervisely/app/singleton.py +21 -0
- supervisely/app/v1/app_service.py +18 -2
- supervisely/app/v1/constants.py +7 -1
- supervisely/app/widgets/__init__.py +11 -1
- supervisely/app/widgets/agent_selector/template.html +1 -0
- supervisely/app/widgets/card/card.py +20 -0
- supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
- supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
- supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
- supervisely/app/widgets/dialog/dialog.py +12 -0
- supervisely/app/widgets/dialog/template.html +2 -1
- supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
- supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
- supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
- supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
- supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
- supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
- supervisely/app/widgets/fast_table/fast_table.py +713 -126
- supervisely/app/widgets/fast_table/script.js +492 -95
- supervisely/app/widgets/fast_table/style.css +54 -0
- supervisely/app/widgets/fast_table/template.html +45 -5
- supervisely/app/widgets/heatmap/__init__.py +0 -0
- supervisely/app/widgets/heatmap/heatmap.py +523 -0
- supervisely/app/widgets/heatmap/script.js +378 -0
- supervisely/app/widgets/heatmap/style.css +227 -0
- supervisely/app/widgets/heatmap/template.html +21 -0
- supervisely/app/widgets/input_tag/input_tag.py +102 -15
- supervisely/app/widgets/input_tag_list/__init__.py +0 -0
- supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
- supervisely/app/widgets/input_tag_list/template.html +70 -0
- supervisely/app/widgets/radio_table/radio_table.py +10 -2
- supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
- supervisely/app/widgets/radio_tabs/template.html +1 -0
- supervisely/app/widgets/select/select.py +6 -4
- supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
- supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
- supervisely/app/widgets/table/table.py +68 -13
- supervisely/app/widgets/tabs/tabs.py +22 -6
- supervisely/app/widgets/tabs/template.html +5 -1
- supervisely/app/widgets/transfer/style.css +3 -0
- supervisely/app/widgets/transfer/template.html +3 -1
- supervisely/app/widgets/transfer/transfer.py +48 -45
- supervisely/app/widgets/tree_select/tree_select.py +2 -0
- supervisely/convert/image/csv/csv_converter.py +24 -15
- supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
- supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
- supervisely/convert/video/video_converter.py +2 -2
- supervisely/geometry/polyline_3d.py +110 -0
- supervisely/io/env.py +161 -1
- supervisely/nn/artifacts/__init__.py +1 -1
- supervisely/nn/artifacts/artifacts.py +10 -2
- supervisely/nn/artifacts/detectron2.py +1 -0
- supervisely/nn/artifacts/hrda.py +1 -0
- supervisely/nn/artifacts/mmclassification.py +20 -0
- supervisely/nn/artifacts/mmdetection.py +5 -3
- supervisely/nn/artifacts/mmsegmentation.py +1 -0
- supervisely/nn/artifacts/ritm.py +1 -0
- supervisely/nn/artifacts/rtdetr.py +1 -0
- supervisely/nn/artifacts/unet.py +1 -0
- supervisely/nn/artifacts/utils.py +3 -0
- supervisely/nn/artifacts/yolov5.py +2 -0
- supervisely/nn/artifacts/yolov8.py +1 -0
- supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
- supervisely/nn/experiments.py +9 -0
- supervisely/nn/inference/cache.py +37 -17
- supervisely/nn/inference/gui/serving_gui_template.py +39 -13
- supervisely/nn/inference/inference.py +953 -211
- supervisely/nn/inference/inference_request.py +15 -8
- supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
- supervisely/nn/inference/object_detection/object_detection.py +1 -0
- supervisely/nn/inference/predict_app/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
- supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
- supervisely/nn/inference/predict_app/gui/gui.py +915 -0
- supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
- supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
- supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
- supervisely/nn/inference/predict_app/gui/preview.py +93 -0
- supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
- supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
- supervisely/nn/inference/predict_app/gui/utils.py +399 -0
- supervisely/nn/inference/predict_app/predict_app.py +176 -0
- supervisely/nn/inference/session.py +47 -39
- supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
- supervisely/nn/inference/tracking/point_tracking.py +5 -1
- supervisely/nn/inference/tracking/tracker_interface.py +4 -0
- supervisely/nn/inference/uploader.py +9 -5
- supervisely/nn/model/model_api.py +44 -22
- supervisely/nn/model/prediction.py +15 -1
- supervisely/nn/model/prediction_session.py +70 -14
- supervisely/nn/prediction_dto.py +7 -0
- supervisely/nn/tracker/__init__.py +6 -8
- supervisely/nn/tracker/base_tracker.py +54 -0
- supervisely/nn/tracker/botsort/__init__.py +1 -0
- supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
- supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
- supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
- supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
- supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
- supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
- supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
- supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
- supervisely/nn/tracker/botsort_tracker.py +273 -0
- supervisely/nn/tracker/calculate_metrics.py +264 -0
- supervisely/nn/tracker/utils.py +273 -0
- supervisely/nn/tracker/visualize.py +520 -0
- supervisely/nn/training/gui/gui.py +152 -49
- supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
- supervisely/nn/training/gui/model_selector.py +8 -6
- supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
- supervisely/nn/training/gui/training_artifacts.py +3 -1
- supervisely/nn/training/train_app.py +225 -46
- supervisely/project/pointcloud_episode_project.py +12 -8
- supervisely/project/pointcloud_project.py +12 -8
- supervisely/project/project.py +221 -75
- supervisely/template/experiment/experiment.html.jinja +105 -55
- supervisely/template/experiment/experiment_generator.py +258 -112
- supervisely/template/experiment/header.html.jinja +31 -13
- supervisely/template/experiment/sly-style.css +7 -2
- supervisely/versions.json +3 -1
- supervisely/video/sampling.py +42 -20
- supervisely/video/video.py +41 -12
- supervisely/video_annotation/video_figure.py +38 -4
- supervisely/volume/stl_converter.py +2 -0
- supervisely/worker_api/agent_rpc.py +24 -1
- supervisely/worker_api/rpc_servicer.py +31 -7
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
- supervisely_lib/__init__.py +6 -1
- supervisely/app/widgets/experiment_selector/style.css +0 -27
- supervisely/app/widgets/experiment_selector/template.html +0 -61
- supervisely/nn/tracker/bot_sort/__init__.py +0 -21
- supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
- supervisely/nn/tracker/bot_sort/matching.py +0 -127
- supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
- supervisely/nn/tracker/deep_sort/__init__.py +0 -6
- supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
- supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
- supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
- supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
- supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
- supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
- supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
- supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
- supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
- supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
- supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
- supervisely/nn/tracker/tracker.py +0 -285
- supervisely/nn/tracker/utils/kalman_filter.py +0 -492
- supervisely/nn/tracking/__init__.py +0 -1
- supervisely/nn/tracking/boxmot.py +0 -114
- supervisely/nn/tracking/tracking.py +0 -24
- /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
supervisely/nn/tracker/botsort_tracker.py
@@ -0,0 +1,273 @@
+import supervisely as sly
+from supervisely.nn.tracker.base_tracker import BaseTracker
+from supervisely import Annotation, VideoAnnotation
+from supervisely.annotation.label import LabelingStatus
+from dataclasses import dataclass
+from types import SimpleNamespace
+from typing import List, Dict, Tuple, Any, Optional
+import numpy as np
+import yaml
+import os
+from pathlib import Path
+from supervisely import logger
+from supervisely.nn.tracker.botsort.tracker.mc_bot_sort import BoTSORT
+
+
+@dataclass
+class TrackedObject:
+    """
+    Data class representing a tracked object in a single frame.
+
+    Args:
+        track_id: Unique identifier for the track
+        det_id: Detection ID for mapping back to original annotation
+        bbox: Bounding box coordinates in format [x1, y1, x2, y2]
+        class_name: String class name
+        class_sly_id: Supervisely class ID (from ObjClass.sly_id)
+        score: Confidence score of the detection/track
+    """
+    track_id: int
+    det_id: int
+    bbox: List[float]  # [x1, y1, x2, y2]
+    class_name: str
+    class_sly_id: Optional[int]  # Supervisely class ID
+    score: float
+
+
+class BotSortTracker(BaseTracker):
+
+    def __init__(self, settings: dict = None, device: str = None):
+        super().__init__(settings=settings, device=device)
+
+        from supervisely.nn.tracker import TRACKING_LIBS_INSTALLED
+        if not TRACKING_LIBS_INSTALLED:
+            raise ImportError(
+                "Tracking dependencies are not installed. "
+                "Please install supervisely with `pip install supervisely[tracking]`."
+            )
+
+        # Load default settings from YAML file
+        self.settings = self._load_default_settings()
+
+        # Override with user settings if provided
+        if settings:
+            self.settings.update(settings)
+
+        args = SimpleNamespace(**self.settings)
+        args.name = "BotSORT"
+        args.device = self.device
+
+        self.tracker = BoTSORT(args=args)
+
+        # State for accumulating results
+        self.frame_tracks = []
+        self.obj_classes = {}  # class_name -> ObjClass
+        self.current_frame = 0
+        self.class_ids = {}  # class_name -> class_id mapping
+        self.frame_shape = ()
+
+    def _load_default_settings(self) -> dict:
+        """Internal method: calls classmethod"""
+        return self.get_default_params()
+
+    def update(self, frame: np.ndarray, annotation: Annotation) -> List[Dict[str, Any]]:
+        """Update tracker and return list of matches for current frame."""
+        self.frame_shape = frame.shape[:2]
+        self._update_obj_classes(annotation)
+        detections = self._convert_annotation(annotation)
+        output_stracks, detection_track_map = self.tracker.update(detections, frame)
+        tracks = self._stracks_to_tracks(output_stracks, detection_track_map)
+
+        # Store tracks for VideoAnnotation creation
+        self.frame_tracks.append(tracks)
+        self.current_frame += 1
+
+        matches = []
+        for pair in detection_track_map:
+            det_id = pair["det_id"]
+            track_id = pair["track_id"]
+
+            if track_id is not None:
+                match = {
+                    "track_id": track_id,
+                    "label": annotation.labels[det_id]
+                }
+                matches.append(match)
+
+        return matches
+
+    def reset(self) -> None:
+        super().reset()
+        self.frame_tracks = []
+        self.obj_classes = {}
+        self.current_frame = 0
+        self.class_ids = {}
+        self.frame_shape = ()
+
+    def track(self, frames: List[np.ndarray], annotations: List[Annotation]) -> VideoAnnotation:
+        """Track objects through sequence of frames and return VideoAnnotation."""
+        if len(frames) != len(annotations):
+            raise ValueError("Number of frames and annotations must match")
+
+        self.reset()
+
+        # Process each frame
+        for frame_idx, (frame, annotation) in enumerate(zip(frames, annotations)):
+            self.current_frame = frame_idx
+            self.update(frame, annotation)
+
+        # Convert accumulated tracks to VideoAnnotation
+        return self._create_video_annotation()
+
+    def _convert_annotation(self, annotation: Annotation) -> np.ndarray:
+        """Convert Supervisely annotation to BoTSORT detection format."""
+        detections_list = []
+
+        for label in annotation.labels:
+            if label.tags.get("confidence", None) is not None:
+                confidence = label.tags.get("confidence").value
+            elif label.tags.get("conf", None) is not None:
+                confidence = label.tags.get("conf").value
+            else:
+                confidence = 1.0
+                logger.debug(
+                    f"Label {label.obj_class.name} does not have confidence tag, using default value 1.0"
+                )
+
+            rectangle = label.geometry.to_bbox()
+
+            class_name = label.obj_class.name
+            class_id = self.class_ids[class_name]
+
+            detection = [
+                rectangle.left,    # x1
+                rectangle.top,     # y1
+                rectangle.right,   # x2
+                rectangle.bottom,  # y2
+                confidence,        # score
+                class_id,          # class_id as number
+            ]
+            detections_list.append(detection)
+
+        if detections_list:
+            return np.array(detections_list, dtype=np.float32)
+        else:
+            return np.zeros((0, 6), dtype=np.float32)
+
+    def _stracks_to_tracks(self, output_stracks, detection_track_map) -> List[TrackedObject]:
+        """Convert BoTSORT output tracks to TrackedObject dataclass instances."""
+        tracks = []
+
+        id_to_name = {v: k for k, v in self.class_ids.items()}
+
+        track_id_to_det_id = {}
+        for pair in detection_track_map:
+            det_id = pair["det_id"]
+            track_id = pair["track_id"]
+            track_id_to_det_id[track_id] = det_id
+
+        for strack in output_stracks:
+            # BoTSORT may store class info in different attributes.
+            # Try to get class_id from various possible sources.
+            class_id = 0  # default
+
+            if hasattr(strack, 'cls') and strack.cls != -1:
+                # cls should contain the numeric ID we passed in
+                class_id = int(strack.cls)
+            elif hasattr(strack, 'class_id'):
+                class_id = int(strack.class_id)
+
+            class_name = id_to_name.get(class_id, "unknown")
+
+            # Get Supervisely class ID from stored ObjClass
+            class_sly_id = None
+            if class_name in self.obj_classes:
+                obj_class = self.obj_classes[class_name]
+                class_sly_id = obj_class.sly_id
+
+            track = TrackedObject(
+                track_id=strack.track_id,
+                det_id=track_id_to_det_id.get(strack.track_id),
+                bbox=strack.tlbr.tolist(),  # [x1, y1, x2, y2]
+                class_name=class_name,
+                class_sly_id=class_sly_id,
+                score=getattr(strack, 'score', 1.0)
+            )
+            tracks.append(track)
+
+        return tracks
+
+    def _update_obj_classes(self, annotation: Annotation):
+        """Extract and store object classes from annotation."""
+        for label in annotation.labels:
+            class_name = label.obj_class.name
+            if class_name not in self.obj_classes:
+                self.obj_classes[class_name] = label.obj_class
+
+            if class_name not in self.class_ids:
+                self.class_ids[class_name] = len(self.class_ids)
+
+    def _create_video_annotation(self) -> VideoAnnotation:
+        """Convert accumulated tracking results to Supervisely VideoAnnotation."""
+        img_h, img_w = self.frame_shape
+        video_objects = {}  # track_id -> VideoObject
+        frames = []
+
+        for frame_idx, tracks in enumerate(self.frame_tracks):
+            frame_figures = []
+
+            for track in tracks:
+                track_id = track.track_id
+                bbox = track.bbox  # [x1, y1, x2, y2]
+                class_name = track.class_name
+
+                # Clip bbox to image boundaries
+                x1, y1, x2, y2 = bbox
+                dims = np.array([img_w, img_h, img_w, img_h]) - 1
+                x1, y1, x2, y2 = np.clip([x1, y1, x2, y2], 0, dims)
+
+                # Get or create VideoObject
+                if track_id not in video_objects:
+                    obj_class = self.obj_classes.get(class_name)
+                    if obj_class is None:
+                        continue  # Skip if class not found
+                    video_objects[track_id] = sly.VideoObject(obj_class)
+
+                video_object = video_objects[track_id]
+                rect = sly.Rectangle(top=y1, left=x1, bottom=y2, right=x2)
+                frame_figures.append(sly.VideoFigure(video_object, rect, frame_idx, track_id=str(track_id), status=LabelingStatus.AUTO))
+
+            frames.append(sly.Frame(frame_idx, frame_figures))
+
+        objects = list(video_objects.values())
+
+        return VideoAnnotation(
+            img_size=self.frame_shape,
+            frames_count=len(self.frame_tracks),
+            objects=sly.VideoObjectCollection(objects),
+            frames=sly.FrameCollection(frames)
+        )
+
+    @property
+    def video_annotation(self) -> VideoAnnotation:
+        """Return the accumulated VideoAnnotation."""
+        if not self.frame_tracks:
+            error_msg = (
+                "No tracking data available. "
+                "Please run tracking first using track() method or process frames with update()."
+            )
+            raise ValueError(error_msg)
+
+        return self._create_video_annotation()
+
+    @classmethod
+    def get_default_params(cls) -> Dict[str, Any]:
+        """Public API: get default params WITHOUT creating instance."""
+        current_dir = Path(__file__).parent
+        config_path = current_dir / "botsort/botsort_config.yaml"
+
+        with open(config_path, 'r', encoding='utf-8') as file:
+            return yaml.safe_load(file)
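For reference, a minimal usage sketch of the new BotSortTracker (the loader helpers are hypothetical; frames and per-frame detections are assumed to come from an upstream detector, and only methods shown in the diff above are used):

    import numpy as np
    import supervisely as sly
    from supervisely.nn.tracker.botsort_tracker import BotSortTracker

    frames = load_frames()            # hypothetical helper -> List[np.ndarray], shape (H, W, 3)
    annotations = load_detections()   # hypothetical helper -> List[sly.Annotation], one per frame

    # Defaults are read from botsort/botsort_config.yaml; pass a dict to override them.
    defaults = BotSortTracker.get_default_params()
    tracker = BotSortTracker(settings=defaults, device="cuda:0")

    video_ann = tracker.track(frames, annotations)  # -> sly.VideoAnnotation with track IDs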
supervisely/nn/tracker/calculate_metrics.py
@@ -0,0 +1,264 @@
+import numpy as np
+from collections import defaultdict
+from typing import Dict, List, Union
+
+from scipy.optimize import linear_sum_assignment  # pylint: disable=import-error
+
+import supervisely as sly
+from supervisely.video_annotation.video_annotation import VideoAnnotation
+
+import motmetrics as mm  # pylint: disable=import-error
+
+
+class TrackingEvaluator:
+    """
+    Evaluator for video tracking metrics including MOTA, MOTP, IDF1.
+    """
+
+    def __init__(self, iou_threshold: float = 0.5):
+        """Initialize evaluator with IoU threshold for matching."""
+        from supervisely.nn.tracker import TRACKING_LIBS_INSTALLED
+        if not TRACKING_LIBS_INSTALLED:
+            raise ImportError(
+                "Tracking dependencies are not installed. "
+                "Please install supervisely with `pip install supervisely[tracking]`."
+            )
+
+        if not 0.0 <= iou_threshold <= 1.0:
+            raise ValueError("iou_threshold must be in [0.0, 1.0]")
+        self.iou_threshold = iou_threshold
+
+    def evaluate(
+        self,
+        gt_annotation: VideoAnnotation,
+        pred_annotation: VideoAnnotation,
+    ) -> Dict[str, Union[float, int]]:
+        """Main entry: extract tracks from annotations, compute basic and MOT metrics, return results."""
+        self._validate_annotations(gt_annotation, pred_annotation)
+        self.img_height, self.img_width = gt_annotation.img_size
+
+        gt_tracks = self._extract_tracks(gt_annotation)
+        pred_tracks = self._extract_tracks(pred_annotation)
+
+        basic = self._compute_basic_metrics(gt_tracks, pred_tracks)
+        mot = self._compute_mot_metrics(gt_tracks, pred_tracks)
+
+        results = {
+            # basic detection
+            "precision": basic["precision"],
+            "recall": basic["recall"],
+            "f1": basic["f1"],
+            "avg_iou": basic["avg_iou"],
+            "true_positives": basic["tp"],
+            "false_positives": basic["fp"],
+            "false_negatives": basic["fn"],
+            "total_gt_objects": basic["total_gt"],
+            "total_pred_objects": basic["total_pred"],
+
+            # motmetrics
+            "mota": mot["mota"],
+            "motp": mot["motp"],
+            "idf1": mot["idf1"],
+            "id_switches": mot["id_switches"],
+            "fragmentations": mot["fragmentations"],
+            "num_misses": mot["num_misses"],
+            "num_false_positives": mot["num_false_positives"],
+
+            # config
+            "iou_threshold": self.iou_threshold,
+        }
+        return results
+
+    def _validate_annotations(self, gt: VideoAnnotation, pred: VideoAnnotation):
+        """Minimal type validation for annotations."""
+        if not isinstance(gt, VideoAnnotation) or not isinstance(pred, VideoAnnotation):
+            raise TypeError("gt_annotation and pred_annotation must be VideoAnnotation instances")
+
+    def _extract_tracks(self, annotation: VideoAnnotation) -> Dict[int, List[Dict]]:
+        """
+        Extract tracks from a VideoAnnotation into a dict keyed by frame index.
+        Each element is a dict: {'track_id': int, 'bbox': [x1,y1,x2,y2], 'confidence': float, 'class_name': str}
+        """
+        frames_to_tracks = defaultdict(list)
+
+        for frame in annotation.frames:
+            frame_idx = frame.index
+            for figure in frame.figures:
+                # use track_id if present, otherwise fall back to the object's key int
+                track_id = int(figure.track_id) if figure.track_id is not None else figure.video_object.key().int
+
+                bbox = figure.geometry
+                if not isinstance(bbox, sly.Rectangle):
+                    bbox = bbox.to_bbox()
+
+                x1 = float(bbox.left)
+                y1 = float(bbox.top)
+                x2 = float(bbox.right)
+                y2 = float(bbox.bottom)
+
+                frames_to_tracks[frame_idx].append({
+                    "track_id": track_id,
+                    "bbox": [x1, y1, x2, y2],
+                    "confidence": float(getattr(figure, "confidence", 1.0)),
+                    "class_name": figure.video_object.obj_class.name
+                })
+
+        return dict(frames_to_tracks)
+
+    def _compute_basic_metrics(self, gt_tracks: Dict[int, List[Dict]], pred_tracks: Dict[int, List[Dict]]):
+        """
+        Compute per-frame true positives / false positives / false negatives and average IoU.
+        Matching is performed with the Hungarian algorithm (scipy). Matches with IoU < threshold are discarded.
+        """
+        tp = fp = fn = 0
+        total_iou = 0.0
+        iou_count = 0
+
+        frames = sorted(set(list(gt_tracks.keys()) + list(pred_tracks.keys())))
+        for f in frames:
+            gts = gt_tracks.get(f, [])
+            preds = pred_tracks.get(f, [])
+
+            if not gts and not preds:
+                continue
+            if not gts:
+                fp += len(preds)
+                continue
+            if not preds:
+                fn += len(gts)
+                continue
+
+            gt_boxes = np.array([g["bbox"] for g in gts])
+            pred_boxes = np.array([p["bbox"] for p in preds])
+
+            # get cost matrix from motmetrics (cost = 1 - IoU)
+            cost_mat = mm.distances.iou_matrix(gt_boxes, pred_boxes, max_iou=1.0)
+            # replace NaNs (if any) with a large cost so Hungarian will avoid them
+            cost_for_assignment = np.where(np.isnan(cost_mat), 1e6, cost_mat)
+
+            # Hungarian assignment (minimize cost -> maximize IoU)
+            row_idx, col_idx = linear_sum_assignment(cost_for_assignment)
+
+            matched_gt = set()
+            matched_pred = set()
+            for r, c in zip(row_idx, col_idx):
+                if r < cost_mat.shape[0] and c < cost_mat.shape[1]:
+                    # IoU = 1 - cost
+                    cost_val = cost_mat[r, c]
+                    if np.isnan(cost_val):
+                        continue
+                    iou_val = 1.0 - float(cost_val)
+                    if iou_val >= self.iou_threshold:
+                        matched_gt.add(r)
+                        matched_pred.add(c)
+                        total_iou += iou_val
+                        iou_count += 1
+
+            frame_tp = len(matched_gt)
+            frame_fp = len(preds) - len(matched_pred)
+            frame_fn = len(gts) - len(matched_gt)
+
+            tp += frame_tp
+            fp += frame_fp
+            fn += frame_fn
+
+        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
+        avg_iou = total_iou / iou_count if iou_count > 0 else 0.0
+
+        total_gt = sum(len(v) for v in gt_tracks.values())
+        total_pred = sum(len(v) for v in pred_tracks.values())
+
+        return {
+            "precision": precision,
+            "recall": recall,
+            "f1": f1,
+            "avg_iou": avg_iou,
+            "tp": tp,
+            "fp": fp,
+            "fn": fn,
+            "total_gt": total_gt,
+            "total_pred": total_pred,
+        }
+
+    def _compute_mot_metrics(self, gt_tracks: Dict[int, List[Dict]], pred_tracks: Dict[int, List[Dict]]):
+        """
+        Use motmetrics.MOTAccumulator to collect associations per frame and compute common MOT metrics.
+        The distance matrix is taken directly from motmetrics.distances.iou_matrix (which returns 1 - IoU).
+        Pairs with distance > (1 - iou_threshold) are set to infinity to exclude them from matching.
+        """
+        acc = mm.MOTAccumulator(auto_id=True)
+
+        frames = sorted(set(list(gt_tracks.keys()) + list(pred_tracks.keys())))
+        for f in frames:
+            gts = gt_tracks.get(f, [])
+            preds = pred_tracks.get(f, [])
+
+            gt_ids = [g["track_id"] for g in gts]
+            pred_ids = [p["track_id"] for p in preds]
+
+            if gts and preds:
+                gt_boxes = np.array([g["bbox"] for g in gts])
+                pred_boxes = np.array([p["bbox"] for p in preds])
+
+                # motmetrics provides a distance matrix (1 - IoU)
+                dist_mat = mm.distances.iou_matrix(gt_boxes, pred_boxes, max_iou=1.0)
+                # exclude pairs with IoU < threshold => distance > 1 - threshold
+                dist_mat = np.array(dist_mat, dtype=float)
+                dist_mat[np.isnan(dist_mat)] = np.inf
+                dist_mat[dist_mat > (1.0 - self.iou_threshold)] = np.inf
+            else:
+                dist_mat = np.full((len(gts), len(preds)), np.inf)
+
+            acc.update(gt_ids, pred_ids, dist_mat)
+
+        mh = mm.metrics.create()
+        summary = mh.compute(
+            acc,
+            metrics=[
+                "mota",
+                "motp",
+                "idf1",
+                "num_switches",
+                "num_fragmentations",
+                "num_misses",
+                "num_false_positives",
+            ],
+            name="eval",
+        )
+
+        def get_val(col: str, default=0.0):
+            if summary.empty or col not in summary.columns:
+                return float(default)
+            v = summary.iloc[0][col]
+            return float(v) if not np.isnan(v) else float(default)
+
+        return {
+            "mota": get_val("mota", 0.0),
+            "motp": get_val("motp", 0.0),
+            "idf1": get_val("idf1", 0.0),
+            "id_switches": int(get_val("num_switches", 0.0)),
+            "fragmentations": int(get_val("num_fragmentations", 0.0)),
+            "num_misses": int(get_val("num_misses", 0.0)),
+            "num_false_positives": int(get_val("num_false_positives", 0.0)),
+        }
+
+
+def evaluate(
+    gt_annotation: VideoAnnotation,
+    pred_annotation: VideoAnnotation,
+    iou_threshold: float = 0.5,
+) -> Dict[str, Union[float, int]]:
+    """
+    Evaluate tracking predictions against ground truth.
+
+    Args:
+        gt_annotation: Ground-truth annotation, a supervisely VideoAnnotation containing reference object tracks.
+        pred_annotation: Predicted annotation, a supervisely VideoAnnotation to be compared against the ground truth.
+        iou_threshold: Minimum Intersection-over-Union required for a detection to be considered a valid match.
+
+    Returns:
+        dict: JSON-serializable dict with evaluation metrics.
+    """
+    evaluator = TrackingEvaluator(iou_threshold=iou_threshold)
+    return evaluator.evaluate(gt_annotation, pred_annotation)
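And a minimal sketch of scoring a tracker's output with the new evaluate() helper (the ground-truth loader is hypothetical; the import path mirrors the new calculate_metrics.py module above, and tracker/frames/annotations are as in the BotSortTracker sketch):

    from supervisely.nn.tracker.calculate_metrics import evaluate

    gt_ann = load_gt_video_annotation()            # hypothetical helper -> sly.VideoAnnotation
    pred_ann = tracker.track(frames, annotations)  # e.g. output of BotSortTracker above

    # Detections matching GT with IoU >= 0.5 count as true positives;
    # MOTA/MOTP/IDF1 and ID-switch counts come from motmetrics.
    metrics = evaluate(gt_ann, pred_ann, iou_threshold=0.5)
    print(metrics["mota"], metrics["idf1"], metrics["id_switches"])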