supervisely 6.73.410__py3-none-any.whl → 6.73.470__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (190):
  1. supervisely/__init__.py +136 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/json_geometries_map.py +2 -0
  4. supervisely/annotation/label.py +80 -3
  5. supervisely/api/annotation_api.py +9 -9
  6. supervisely/api/api.py +67 -43
  7. supervisely/api/app_api.py +72 -5
  8. supervisely/api/dataset_api.py +108 -33
  9. supervisely/api/entity_annotation/figure_api.py +113 -49
  10. supervisely/api/image_api.py +82 -0
  11. supervisely/api/module_api.py +10 -0
  12. supervisely/api/nn/deploy_api.py +15 -9
  13. supervisely/api/nn/ecosystem_models_api.py +201 -0
  14. supervisely/api/nn/neural_network_api.py +12 -3
  15. supervisely/api/pointcloud/pointcloud_api.py +38 -0
  16. supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
  17. supervisely/api/project_api.py +213 -6
  18. supervisely/api/task_api.py +11 -1
  19. supervisely/api/video/video_annotation_api.py +4 -2
  20. supervisely/api/video/video_api.py +79 -1
  21. supervisely/api/video/video_figure_api.py +24 -11
  22. supervisely/api/volume/volume_api.py +38 -0
  23. supervisely/app/__init__.py +1 -1
  24. supervisely/app/content.py +14 -6
  25. supervisely/app/fastapi/__init__.py +1 -0
  26. supervisely/app/fastapi/custom_static_files.py +1 -1
  27. supervisely/app/fastapi/multi_user.py +88 -0
  28. supervisely/app/fastapi/subapp.py +175 -42
  29. supervisely/app/fastapi/templating.py +1 -1
  30. supervisely/app/fastapi/websocket.py +77 -9
  31. supervisely/app/singleton.py +21 -0
  32. supervisely/app/v1/app_service.py +18 -2
  33. supervisely/app/v1/constants.py +7 -1
  34. supervisely/app/widgets/__init__.py +11 -1
  35. supervisely/app/widgets/agent_selector/template.html +1 -0
  36. supervisely/app/widgets/card/card.py +20 -0
  37. supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
  38. supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
  39. supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
  40. supervisely/app/widgets/dialog/dialog.py +12 -0
  41. supervisely/app/widgets/dialog/template.html +2 -1
  42. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  43. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  44. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  45. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  46. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
  47. supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
  48. supervisely/app/widgets/fast_table/fast_table.py +713 -126
  49. supervisely/app/widgets/fast_table/script.js +492 -95
  50. supervisely/app/widgets/fast_table/style.css +54 -0
  51. supervisely/app/widgets/fast_table/template.html +45 -5
  52. supervisely/app/widgets/heatmap/__init__.py +0 -0
  53. supervisely/app/widgets/heatmap/heatmap.py +523 -0
  54. supervisely/app/widgets/heatmap/script.js +378 -0
  55. supervisely/app/widgets/heatmap/style.css +227 -0
  56. supervisely/app/widgets/heatmap/template.html +21 -0
  57. supervisely/app/widgets/input_tag/input_tag.py +102 -15
  58. supervisely/app/widgets/input_tag_list/__init__.py +0 -0
  59. supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
  60. supervisely/app/widgets/input_tag_list/template.html +70 -0
  61. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  62. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  63. supervisely/app/widgets/radio_tabs/template.html +1 -0
  64. supervisely/app/widgets/select/select.py +6 -4
  65. supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
  66. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
  67. supervisely/app/widgets/table/table.py +68 -13
  68. supervisely/app/widgets/tabs/tabs.py +22 -6
  69. supervisely/app/widgets/tabs/template.html +5 -1
  70. supervisely/app/widgets/transfer/style.css +3 -0
  71. supervisely/app/widgets/transfer/template.html +3 -1
  72. supervisely/app/widgets/transfer/transfer.py +48 -45
  73. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  74. supervisely/convert/image/csv/csv_converter.py +24 -15
  75. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
  76. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
  77. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
  78. supervisely/convert/video/video_converter.py +2 -2
  79. supervisely/geometry/polyline_3d.py +110 -0
  80. supervisely/io/env.py +161 -1
  81. supervisely/nn/artifacts/__init__.py +1 -1
  82. supervisely/nn/artifacts/artifacts.py +10 -2
  83. supervisely/nn/artifacts/detectron2.py +1 -0
  84. supervisely/nn/artifacts/hrda.py +1 -0
  85. supervisely/nn/artifacts/mmclassification.py +20 -0
  86. supervisely/nn/artifacts/mmdetection.py +5 -3
  87. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  88. supervisely/nn/artifacts/ritm.py +1 -0
  89. supervisely/nn/artifacts/rtdetr.py +1 -0
  90. supervisely/nn/artifacts/unet.py +1 -0
  91. supervisely/nn/artifacts/utils.py +3 -0
  92. supervisely/nn/artifacts/yolov5.py +2 -0
  93. supervisely/nn/artifacts/yolov8.py +1 -0
  94. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  95. supervisely/nn/experiments.py +9 -0
  96. supervisely/nn/inference/cache.py +37 -17
  97. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  98. supervisely/nn/inference/inference.py +953 -211
  99. supervisely/nn/inference/inference_request.py +15 -8
  100. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  101. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  102. supervisely/nn/inference/predict_app/__init__.py +0 -0
  103. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  104. supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
  105. supervisely/nn/inference/predict_app/gui/gui.py +915 -0
  106. supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
  107. supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
  108. supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
  109. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  110. supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
  111. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  112. supervisely/nn/inference/predict_app/gui/utils.py +399 -0
  113. supervisely/nn/inference/predict_app/predict_app.py +176 -0
  114. supervisely/nn/inference/session.py +47 -39
  115. supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
  116. supervisely/nn/inference/tracking/point_tracking.py +5 -1
  117. supervisely/nn/inference/tracking/tracker_interface.py +4 -0
  118. supervisely/nn/inference/uploader.py +9 -5
  119. supervisely/nn/model/model_api.py +44 -22
  120. supervisely/nn/model/prediction.py +15 -1
  121. supervisely/nn/model/prediction_session.py +70 -14
  122. supervisely/nn/prediction_dto.py +7 -0
  123. supervisely/nn/tracker/__init__.py +6 -8
  124. supervisely/nn/tracker/base_tracker.py +54 -0
  125. supervisely/nn/tracker/botsort/__init__.py +1 -0
  126. supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
  127. supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
  128. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  129. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  130. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  131. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  132. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  133. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  134. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  135. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  136. supervisely/nn/tracker/botsort_tracker.py +273 -0
  137. supervisely/nn/tracker/calculate_metrics.py +264 -0
  138. supervisely/nn/tracker/utils.py +273 -0
  139. supervisely/nn/tracker/visualize.py +520 -0
  140. supervisely/nn/training/gui/gui.py +152 -49
  141. supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
  142. supervisely/nn/training/gui/model_selector.py +8 -6
  143. supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
  144. supervisely/nn/training/gui/training_artifacts.py +3 -1
  145. supervisely/nn/training/train_app.py +225 -46
  146. supervisely/project/pointcloud_episode_project.py +12 -8
  147. supervisely/project/pointcloud_project.py +12 -8
  148. supervisely/project/project.py +221 -75
  149. supervisely/template/experiment/experiment.html.jinja +105 -55
  150. supervisely/template/experiment/experiment_generator.py +258 -112
  151. supervisely/template/experiment/header.html.jinja +31 -13
  152. supervisely/template/experiment/sly-style.css +7 -2
  153. supervisely/versions.json +3 -1
  154. supervisely/video/sampling.py +42 -20
  155. supervisely/video/video.py +41 -12
  156. supervisely/video_annotation/video_figure.py +38 -4
  157. supervisely/volume/stl_converter.py +2 -0
  158. supervisely/worker_api/agent_rpc.py +24 -1
  159. supervisely/worker_api/rpc_servicer.py +31 -7
  160. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
  161. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
  162. supervisely_lib/__init__.py +6 -1
  163. supervisely/app/widgets/experiment_selector/style.css +0 -27
  164. supervisely/app/widgets/experiment_selector/template.html +0 -61
  165. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  166. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  167. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  168. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  169. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  170. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  171. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  172. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  173. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  174. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  175. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  176. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  177. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  178. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  179. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  180. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  181. supervisely/nn/tracker/tracker.py +0 -285
  182. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  183. supervisely/nn/tracking/__init__.py +0 -1
  184. supervisely/nn/tracking/boxmot.py +0 -114
  185. supervisely/nn/tracking/tracking.py +0 -24
  186. /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
  187. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
  188. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
  189. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
  190. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,273 @@
1
+ import supervisely as sly
2
+ from supervisely.nn.tracker.base_tracker import BaseTracker
3
+ from supervisely import Annotation, VideoAnnotation
4
+ from supervisely.annotation.label import LabelingStatus
5
+ from dataclasses import dataclass
6
+ from types import SimpleNamespace
7
+ from typing import List, Dict, Tuple, Any, Optional
8
+ import numpy as np
9
+ import yaml
10
+ import os
11
+ from pathlib import Path
12
+ from supervisely import logger
13
+ from supervisely.nn.tracker.botsort.tracker.mc_bot_sort import BoTSORT
14
+
15
+
16
@dataclass
class TrackedObject:
    """
    Data class representing a tracked object in a single frame.

    Attributes:
        track_id: Unique identifier for the track.
        det_id: Index of the source label in the frame's annotation
            (used as ``annotation.labels[det_id]``), or ``None`` when the track
            has no detection mapped to it in this frame.
        bbox: Bounding box coordinates in format [x1, y1, x2, y2].
        class_name: String class name.
        class_sly_id: Supervisely class ID (from ObjClass.sly_id), or ``None``
            if it is not known.
        score: Confidence score of the detection/track.
    """

    track_id: int
    det_id: Optional[int]  # None when no detection was mapped to this track
    bbox: List[float]  # [x1, y1, x2, y2]
    class_name: str
    class_sly_id: Optional[int]  # Supervisely class ID
    score: float
35
+
36
+
37
class BotSortTracker(BaseTracker):
    """
    Multi-object tracker that wraps BoT-SORT for Supervisely annotations.

    Frames are fed one at a time with :meth:`update` or all at once with
    :meth:`track`. Per-frame results are accumulated and can be exported as a
    :class:`VideoAnnotation` via :meth:`track` or :attr:`video_annotation`.
    """

    def __init__(self, settings: dict = None, device: str = None):
        """
        Args:
            settings: Optional overrides applied on top of the default BoT-SORT
                parameters loaded by :meth:`get_default_params`.
            device: Device forwarded to the underlying BoT-SORT tracker.

        Raises:
            ImportError: If the optional tracking dependencies are not installed.
        """
        super().__init__(settings=settings, device=device)

        from supervisely.nn.tracker import TRACKING_LIBS_INSTALLED

        if not TRACKING_LIBS_INSTALLED:
            raise ImportError(
                "Tracking dependencies are not installed. "
                "Please install supervisely with `pip install supervisely[tracking]`."
            )

        # Start from the YAML defaults, then apply user overrides.
        self.settings = self._load_default_settings()
        if settings:
            self.settings.update(settings)

        args = SimpleNamespace(**self.settings)
        args.name = "BotSORT"
        args.device = self.device

        self.tracker = BoTSORT(args=args)

        # State accumulated across update() calls
        self.frame_tracks = []  # one List[TrackedObject] per processed frame
        self.obj_classes = {}  # class_name -> ObjClass
        self.current_frame = 0
        self.class_ids = {}  # class_name -> numeric class id passed to BoT-SORT
        self.frame_shape = ()  # (height, width) of the last processed frame

    def _load_default_settings(self) -> dict:
        """Internal method: calls classmethod"""
        return self.get_default_params()

    def update(self, frame: np.ndarray, annotation: Annotation) -> List[Dict[str, Any]]:
        """
        Update tracker and return list of matches for current frame.

        Args:
            frame: Current frame image; its first two dimensions are stored as
                the frame shape (height, width).
            annotation: Detections for the current frame.

        Returns:
            One ``{"track_id": ..., "label": ...}`` dict per detection that
            BoT-SORT associated with a track.
        """
        self.frame_shape = frame.shape[:2]
        self._update_obj_classes(annotation)
        detections = self._convert_annotation(annotation)
        output_stracks, detection_track_map = self.tracker.update(detections, frame)
        tracks = self._stracks_to_tracks(output_stracks, detection_track_map)

        # Store tracks for VideoAnnotation creation
        self.frame_tracks.append(tracks)
        self.current_frame += 1

        # det_id indexes annotation.labels: detections were built in label order.
        return [
            {"track_id": pair["track_id"], "label": annotation.labels[pair["det_id"]]}
            for pair in detection_track_map
            if pair["track_id"] is not None
        ]

    def reset(self) -> None:
        """Clear all accumulated per-sequence state."""
        super().reset()
        # NOTE(review): self.tracker (the BoT-SORT instance) is not re-created here;
        # unless BaseTracker.reset() handles it, BoT-SORT's internal track state
        # may carry over between sequences — confirm.
        self.frame_tracks = []
        self.obj_classes = {}
        self.current_frame = 0
        self.class_ids = {}
        self.frame_shape = ()

    def track(self, frames: List[np.ndarray], annotations: List[Annotation]) -> VideoAnnotation:
        """
        Track objects through sequence of frames and return VideoAnnotation.

        Args:
            frames: Frame images, in order.
            annotations: Detections for each frame, aligned with ``frames``.

        Returns:
            VideoAnnotation with one VideoObject per track.

        Raises:
            ValueError: If ``frames`` and ``annotations`` differ in length, or
                if no frames are given.
        """
        if len(frames) != len(annotations):
            raise ValueError("Number of frames and annotations must match")
        # len() rather than truthiness: frames may be a stacked numpy array.
        if len(frames) == 0:
            raise ValueError("At least one frame is required for tracking")

        self.reset()

        for frame_idx, (frame, annotation) in enumerate(zip(frames, annotations)):
            self.current_frame = frame_idx
            self.update(frame, annotation)

        return self._create_video_annotation()

    def _convert_annotation(self, annotation: Annotation) -> np.ndarray:
        """
        Convert Supervisely annotation to BoTSORT detection format.

        Requires :meth:`_update_obj_classes` to have registered every label's
        class first, since it looks up ``self.class_ids``.

        Returns:
            float32 array with rows ``[x1, y1, x2, y2, score, class_id]``;
            shape (0, 6) when the annotation has no labels.
        """
        detections_list = []

        for label in annotation.labels:
            confidence = self._get_label_confidence(label)
            rectangle = label.geometry.to_bbox()
            class_id = self.class_ids[label.obj_class.name]
            detections_list.append(
                [
                    rectangle.left,  # x1
                    rectangle.top,  # y1
                    rectangle.right,  # x2
                    rectangle.bottom,  # y2
                    confidence,  # score
                    class_id,  # class_id as number
                ]
            )

        if not detections_list:
            return np.zeros((0, 6), dtype=np.float32)
        return np.array(detections_list, dtype=np.float32)

    @staticmethod
    def _get_label_confidence(label) -> Any:
        """Return the value of the label's "confidence" (or else "conf") tag, or 1.0 if neither exists."""
        for tag_name in ("confidence", "conf"):
            tag = label.tags.get(tag_name, None)
            if tag is not None:
                return tag.value
        logger.debug(
            f"Label {label.obj_class.name} does not have confidence tag, using default value 1.0"
        )
        return 1.0

    def _stracks_to_tracks(self, output_stracks, detection_track_map) -> List[TrackedObject]:
        """Convert BoTSORT output tracks to TrackedObject dataclass instances."""
        id_to_name = {class_id: name for name, class_id in self.class_ids.items()}
        # If a track_id appears more than once, the last pair wins.
        track_id_to_det_id = {pair["track_id"]: pair["det_id"] for pair in detection_track_map}

        tracks = []
        for strack in output_stracks:
            # BoTSORT may store the class as `cls` (-1 when unset) or as `class_id`.
            class_id = 0  # default
            if hasattr(strack, "cls") and strack.cls != -1:
                class_id = int(strack.cls)
            elif hasattr(strack, "class_id"):
                class_id = int(strack.class_id)

            class_name = id_to_name.get(class_id, "unknown")
            obj_class = self.obj_classes.get(class_name)
            class_sly_id = obj_class.sly_id if obj_class is not None else None

            tracks.append(
                TrackedObject(
                    track_id=strack.track_id,
                    det_id=track_id_to_det_id.get(strack.track_id),
                    bbox=strack.tlbr.tolist(),  # [x1, y1, x2, y2]
                    class_name=class_name,
                    class_sly_id=class_sly_id,
                    score=getattr(strack, "score", 1.0),
                )
            )

        return tracks

    def _update_obj_classes(self, annotation: Annotation):
        """Register each new class: keep its ObjClass and assign the next numeric id."""
        for label in annotation.labels:
            class_name = label.obj_class.name
            self.obj_classes.setdefault(class_name, label.obj_class)
            if class_name not in self.class_ids:
                self.class_ids[class_name] = len(self.class_ids)

    def _create_video_annotation(self) -> VideoAnnotation:
        """
        Convert accumulated tracking results to Supervisely VideoAnnotation.

        Expects at least one processed frame, so that ``self.frame_shape`` is
        ``(height, width)``.
        """
        img_h, img_w = self.frame_shape
        # Largest valid pixel coordinate per bbox component; the same for every track.
        max_coords = np.array([img_w, img_h, img_w, img_h]) - 1
        video_objects = {}  # track_id -> VideoObject
        frames = []

        for frame_idx, tracks in enumerate(self.frame_tracks):
            frame_figures = []

            for track in tracks:
                track_id = track.track_id

                # Get or create the VideoObject for this track
                if track_id not in video_objects:
                    obj_class = self.obj_classes.get(track.class_name)
                    if obj_class is None:
                        continue  # e.g. class resolved as "unknown"; skip
                    video_objects[track_id] = sly.VideoObject(obj_class)

                # Clip bbox to image boundaries
                x1, y1, x2, y2 = track.bbox
                x1, y1, x2, y2 = np.clip([x1, y1, x2, y2], 0, max_coords)

                rect = sly.Rectangle(top=y1, left=x1, bottom=y2, right=x2)
                frame_figures.append(
                    sly.VideoFigure(
                        video_objects[track_id],
                        rect,
                        frame_idx,
                        track_id=str(track_id),
                        status=LabelingStatus.AUTO,
                    )
                )

            frames.append(sly.Frame(frame_idx, frame_figures))

        return VideoAnnotation(
            img_size=self.frame_shape,
            frames_count=len(self.frame_tracks),
            objects=sly.VideoObjectCollection(list(video_objects.values())),
            frames=sly.FrameCollection(frames),
        )

    @property
    def video_annotation(self) -> VideoAnnotation:
        """
        Return the accumulated VideoAnnotation.

        Raises:
            ValueError: If no frames have been processed yet.
        """
        if not self.frame_tracks:
            error_msg = (
                "No tracking data available. "
                "Please run tracking first using track() method or process frames with update()."
            )
            raise ValueError(error_msg)

        return self._create_video_annotation()

    @classmethod
    def get_default_params(cls) -> Dict[str, Any]:
        """Public API: get default params WITHOUT creating instance."""
        config_path = Path(__file__).parent / "botsort" / "botsort_config.yaml"

        with open(config_path, "r", encoding="utf-8") as file:
            params = yaml.safe_load(file)
        # An empty YAML file loads as None; normalize so callers can .update() it.
        return params or {}
273
+
@@ -0,0 +1,264 @@
1
+ import numpy as np
2
+ from collections import defaultdict
3
+ from typing import Dict, List, Union
4
+
5
+ from scipy.optimize import linear_sum_assignment # pylint: disable=import-error
6
+
7
+ import supervisely as sly
8
+ from supervisely.video_annotation.video_annotation import VideoAnnotation
9
+
10
+ import motmetrics as mm # pylint: disable=import-error
11
+
12
class TrackingEvaluator:
    """
    Evaluator for video tracking metrics including MOTA, MOTP, IDF1.

    Ground-truth and predicted boxes are matched per frame by IoU. A pair whose
    IoU is below ``iou_threshold`` is never counted as a match.
    """

    def __init__(self, iou_threshold: float = 0.5):
        """
        Initialize evaluator with IoU threshold for matching.

        Args:
            iou_threshold: Minimum IoU for a GT/prediction pair to be a match.

        Raises:
            ImportError: If the optional tracking dependencies are not installed.
            ValueError: If ``iou_threshold`` is outside [0.0, 1.0].
        """
        from supervisely.nn.tracker import TRACKING_LIBS_INSTALLED

        if not TRACKING_LIBS_INSTALLED:
            raise ImportError(
                "Tracking dependencies are not installed. "
                "Please install supervisely with `pip install supervisely[tracking]`."
            )

        if not 0.0 <= iou_threshold <= 1.0:
            raise ValueError("iou_threshold must be in [0.0, 1.0]")
        self.iou_threshold = iou_threshold

    def evaluate(
        self,
        gt_annotation: VideoAnnotation,
        pred_annotation: VideoAnnotation,
    ) -> Dict[str, Union[float, int]]:
        """
        Main entry: extract tracks from annotations, compute basic and MOT metrics, return results.

        Raises:
            TypeError: If either argument is not a VideoAnnotation.
        """
        self._validate_annotations(gt_annotation, pred_annotation)
        # Stored as attributes; not used by the metric computations in this class.
        self.img_height, self.img_width = gt_annotation.img_size

        gt_tracks = self._extract_tracks(gt_annotation)
        pred_tracks = self._extract_tracks(pred_annotation)

        basic = self._compute_basic_metrics(gt_tracks, pred_tracks)
        mot = self._compute_mot_metrics(gt_tracks, pred_tracks)

        results = {
            # basic detection
            "precision": basic["precision"],
            "recall": basic["recall"],
            "f1": basic["f1"],
            "avg_iou": basic["avg_iou"],
            "true_positives": basic["tp"],
            "false_positives": basic["fp"],
            "false_negatives": basic["fn"],
            "total_gt_objects": basic["total_gt"],
            "total_pred_objects": basic["total_pred"],
            # motmetrics
            "mota": mot["mota"],
            "motp": mot["motp"],
            "idf1": mot["idf1"],
            "id_switches": mot["id_switches"],
            "fragmentations": mot["fragmentations"],
            "num_misses": mot["num_misses"],
            "num_false_positives": mot["num_false_positives"],
            # config
            "iou_threshold": self.iou_threshold,
        }
        return results

    def _validate_annotations(self, gt: VideoAnnotation, pred: VideoAnnotation):
        """Minimal type validation for annotations."""
        if not isinstance(gt, VideoAnnotation) or not isinstance(pred, VideoAnnotation):
            raise TypeError("gt_annotation and pred_annotation must be VideoAnnotation instances")

    @staticmethod
    def _figure_track_id(figure) -> int:
        """
        Return an integer track id for a figure.

        Uses ``figure.track_id`` when it parses as an int. Otherwise (missing or
        non-numeric) falls back to the int of the figure's VideoObject key.
        """
        if figure.track_id is not None:
            try:
                return int(figure.track_id)
            except (TypeError, ValueError):
                pass  # non-numeric track id: fall through to the object key
        return figure.video_object.key().int

    def _extract_tracks(self, annotation: VideoAnnotation) -> Dict[int, List[Dict]]:
        """
        Extract tracks from a VideoAnnotation into a dict keyed by frame index.
        Each element is a dict: {'track_id': int, 'bbox': [x1,y1,x2,y2], 'confidence': float, 'class_name': str}
        """
        frames_to_tracks = defaultdict(list)

        for frame in annotation.frames:
            for figure in frame.figures:
                bbox = figure.geometry
                if not isinstance(bbox, sly.Rectangle):
                    bbox = bbox.to_bbox()

                frames_to_tracks[frame.index].append(
                    {
                        "track_id": self._figure_track_id(figure),
                        "bbox": [float(bbox.left), float(bbox.top), float(bbox.right), float(bbox.bottom)],
                        "confidence": float(getattr(figure, "confidence", 1.0)),
                        "class_name": figure.video_object.obj_class.name,
                    }
                )

        return dict(frames_to_tracks)

    @staticmethod
    def _frame_indices(gt_tracks: Dict[int, List[Dict]], pred_tracks: Dict[int, List[Dict]]) -> List[int]:
        """Return the sorted union of frame indices present in either track dict."""
        return sorted(set(gt_tracks) | set(pred_tracks))

    @staticmethod
    def _iou_distance_matrix(gt_boxes, pred_boxes) -> np.ndarray:
        """
        Return the (len(gt_boxes), len(pred_boxes)) matrix of ``1 - IoU``.

        Boxes are in corner format [x1, y1, x2, y2], as produced by
        ``_extract_tracks``. IoU is computed here directly because
        ``motmetrics.distances.iou_matrix`` expects [x, y, width, height] boxes.
        Pairs with empty intersection or zero union get IoU 0 (distance 1).
        """
        gt = np.asarray(gt_boxes, dtype=float).reshape(-1, 4)[:, None, :]  # (N, 1, 4)
        pr = np.asarray(pred_boxes, dtype=float).reshape(-1, 4)[None, :, :]  # (1, M, 4)

        inter_w = np.clip(np.minimum(gt[..., 2], pr[..., 2]) - np.maximum(gt[..., 0], pr[..., 0]), 0.0, None)
        inter_h = np.clip(np.minimum(gt[..., 3], pr[..., 3]) - np.maximum(gt[..., 1], pr[..., 1]), 0.0, None)
        inter = inter_w * inter_h  # (N, M)

        gt_area = np.clip(gt[..., 2] - gt[..., 0], 0.0, None) * np.clip(gt[..., 3] - gt[..., 1], 0.0, None)
        pr_area = np.clip(pr[..., 2] - pr[..., 0], 0.0, None) * np.clip(pr[..., 3] - pr[..., 1], 0.0, None)
        union = gt_area + pr_area - inter  # broadcasts to (N, M)

        iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
        return 1.0 - iou

    def _compute_basic_metrics(self, gt_tracks: Dict[int, List[Dict]], pred_tracks: Dict[int, List[Dict]]):
        """
        Compute per-frame true positives / false positives / false negatives and average IoU.

        Matching uses the Hungarian algorithm (scipy). Pairs below the IoU
        threshold get a prohibitive cost. The assignment therefore maximizes the
        number of valid matches first, then their total IoU. Pairs below the
        threshold are never counted as matches.
        """
        tp = fp = fn = 0
        total_iou = 0.0
        iou_count = 0
        max_dist = 1.0 - self.iou_threshold

        for f in self._frame_indices(gt_tracks, pred_tracks):
            gts = gt_tracks.get(f, [])
            preds = pred_tracks.get(f, [])

            if not gts and not preds:
                continue
            if not gts:
                fp += len(preds)
                continue
            if not preds:
                fn += len(gts)
                continue

            dist = self._iou_distance_matrix([g["bbox"] for g in gts], [p["bbox"] for p in preds])
            admissible = np.isfinite(dist) & (dist <= max_dist)
            # Same gating as _compute_mot_metrics. Without it, Hungarian could pick
            # below-threshold pairs and miss a valid assignment.
            cost = np.where(admissible, dist, 1e6)

            row_idx, col_idx = linear_sum_assignment(cost)

            frame_matches = 0
            for r, c in zip(row_idx, col_idx):
                if admissible[r, c]:
                    frame_matches += 1
                    total_iou += 1.0 - float(dist[r, c])
                    iou_count += 1

            tp += frame_matches
            fp += len(preds) - frame_matches
            fn += len(gts) - frame_matches

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        avg_iou = total_iou / iou_count if iou_count > 0 else 0.0

        total_gt = sum(len(v) for v in gt_tracks.values())
        total_pred = sum(len(v) for v in pred_tracks.values())

        return {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "avg_iou": avg_iou,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "total_gt": total_gt,
            "total_pred": total_pred,
        }

    def _compute_mot_metrics(self, gt_tracks: Dict[int, List[Dict]], pred_tracks: Dict[int, List[Dict]]):
        """
        Use motmetrics.MOTAccumulator to collect associations per frame and compute common MOT metrics.

        Distances are ``1 - IoU`` (see ``_iou_distance_matrix``). Pairs with
        distance > (1 - iou_threshold) are set to infinity to exclude them from
        matching. Per the motmetrics docs, MOTP is the mean distance of matched
        pairs, i.e. ``1 - IoU`` here (lower is better).
        """
        acc = mm.MOTAccumulator(auto_id=True)
        max_dist = 1.0 - self.iou_threshold

        for f in self._frame_indices(gt_tracks, pred_tracks):
            gts = gt_tracks.get(f, [])
            preds = pred_tracks.get(f, [])

            gt_ids = [g["track_id"] for g in gts]
            pred_ids = [p["track_id"] for p in preds]

            if gts and preds:
                dist_mat = self._iou_distance_matrix([g["bbox"] for g in gts], [p["bbox"] for p in preds])
                dist_mat[np.isnan(dist_mat)] = np.inf
                dist_mat[dist_mat > max_dist] = np.inf
            else:
                dist_mat = np.full((len(gts), len(preds)), np.inf)

            acc.update(gt_ids, pred_ids, dist_mat)

        mh = mm.metrics.create()
        summary = mh.compute(
            acc,
            metrics=[
                "mota",
                "motp",
                "idf1",
                "num_switches",
                "num_fragmentations",
                "num_misses",
                "num_false_positives",
            ],
            name="eval",
        )

        def get_val(col: str, default=0.0):
            if summary.empty or col not in summary.columns:
                return float(default)
            v = summary.iloc[0][col]
            return float(v) if not np.isnan(v) else float(default)

        return {
            "mota": get_val("mota", 0.0),
            "motp": get_val("motp", 0.0),
            "idf1": get_val("idf1", 0.0),
            "id_switches": int(get_val("num_switches", 0.0)),
            "fragmentations": int(get_val("num_fragmentations", 0.0)),
            "num_misses": int(get_val("num_misses", 0.0)),
            "num_false_positives": int(get_val("num_false_positives", 0.0)),
        }
245
+
246
+
247
def evaluate(
    gt_annotation: VideoAnnotation,
    pred_annotation: VideoAnnotation,
    iou_threshold: float = 0.5,
) -> Dict[str, Union[float, int]]:
    """
    Compare predicted object tracks with ground-truth tracks and report tracking metrics.

    Convenience wrapper: builds a ``TrackingEvaluator`` with the given
    threshold and runs its ``evaluate`` method.

    Args:
        gt_annotation: Ground-truth supervisely VideoAnnotation holding the reference object tracks.
        pred_annotation: Predicted supervisely VideoAnnotation to score against the ground truth.
        iou_threshold: Minimum Intersection-over-Union for a prediction to match a ground-truth box.

    Returns:
        dict: Metric name -> value (precision/recall/F1, MOTA, MOTP, IDF1, counts, and the threshold used).

    Raises:
        ValueError: If ``iou_threshold`` is outside [0.0, 1.0].
        TypeError: If either annotation is not a VideoAnnotation.
        ImportError: If the optional tracking dependencies are not installed.
    """
    return TrackingEvaluator(iou_threshold=iou_threshold).evaluate(gt_annotation, pred_annotation)