supervisely-6.73.427-py3-none-any.whl → supervisely-6.73.429-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,264 @@
+ import numpy as np
+ from collections import defaultdict
+ from typing import Dict, List, Union
+
+ from scipy.optimize import linear_sum_assignment  # pylint: disable=import-error
+
+ import supervisely as sly
+ from supervisely.video_annotation.video_annotation import VideoAnnotation
+
+ import motmetrics as mm  # pylint: disable=import-error
+
+ class TrackingEvaluator:
+     """
+     Evaluator for video tracking metrics including MOTA, MOTP, IDF1.
+     """
+
+     def __init__(self, iou_threshold: float = 0.5):
+         """Initialize evaluator with IoU threshold for matching."""
+         from supervisely.nn.tracker import TRACKING_LIBS_INSTALLED
+         if not TRACKING_LIBS_INSTALLED:
+             raise ImportError(
+                 "Tracking dependencies are not installed. "
+                 "Please install supervisely with `pip install supervisely[tracking]`."
+             )
+
+         if not 0.0 <= iou_threshold <= 1.0:
+             raise ValueError("iou_threshold must be in [0.0, 1.0]")
+         self.iou_threshold = iou_threshold
+
+     def evaluate(
+         self,
+         gt_annotation: VideoAnnotation,
+         pred_annotation: VideoAnnotation,
+     ) -> Dict[str, Union[float, int]]:
+         """Main entry: extract tracks from annotations, compute basic and MOT metrics, return results."""
+         self._validate_annotations(gt_annotation, pred_annotation)
+         self.img_height, self.img_width = gt_annotation.img_size
+
+         gt_tracks = self._extract_tracks(gt_annotation)
+         pred_tracks = self._extract_tracks(pred_annotation)
+
+         basic = self._compute_basic_metrics(gt_tracks, pred_tracks)
+         mot = self._compute_mot_metrics(gt_tracks, pred_tracks)
+
+         results = {
+             # basic detection
+             "precision": basic["precision"],
+             "recall": basic["recall"],
+             "f1": basic["f1"],
+             "avg_iou": basic["avg_iou"],
+             "true_positives": basic["tp"],
+             "false_positives": basic["fp"],
+             "false_negatives": basic["fn"],
+             "total_gt_objects": basic["total_gt"],
+             "total_pred_objects": basic["total_pred"],
+
+             # motmetrics
+             "mota": mot["mota"],
+             "motp": mot["motp"],
+             "idf1": mot["idf1"],
+             "id_switches": mot["id_switches"],
+             "fragmentations": mot["fragmentations"],
+             "num_misses": mot["num_misses"],
+             "num_false_positives": mot["num_false_positives"],
+
+             # config
+             "iou_threshold": self.iou_threshold,
+         }
+         return results
+
+     def _validate_annotations(self, gt: VideoAnnotation, pred: VideoAnnotation):
+         """Minimal type validation for annotations."""
+         if not isinstance(gt, VideoAnnotation) or not isinstance(pred, VideoAnnotation):
+             raise TypeError("gt_annotation and pred_annotation must be VideoAnnotation instances")
+
+     def _extract_tracks(self, annotation: VideoAnnotation) -> Dict[int, List[Dict]]:
+         """
+         Extract tracks from a VideoAnnotation into a dict keyed by frame index.
+         Each element is a dict: {'track_id': int, 'bbox': [x1, y1, x2, y2], 'confidence': float, 'class_name': str}
+         """
+         frames_to_tracks = defaultdict(list)
+
+         for frame in annotation.frames:
+             frame_idx = frame.index
+             for figure in frame.figures:
+                 # use track_id if present, otherwise fall back to the object's integer key
+                 track_id = int(figure.track_id) if figure.track_id is not None else figure.video_object.key().int
+
+                 bbox = figure.geometry
+                 if not isinstance(bbox, sly.Rectangle):
+                     bbox = bbox.to_bbox()
+
+                 x1 = float(bbox.left)
+                 y1 = float(bbox.top)
+                 x2 = float(bbox.right)
+                 y2 = float(bbox.bottom)
+
+                 frames_to_tracks[frame_idx].append({
+                     "track_id": track_id,
+                     "bbox": [x1, y1, x2, y2],
+                     "confidence": float(getattr(figure, "confidence", 1.0)),
+                     "class_name": figure.video_object.obj_class.name,
+                 })
+
+         return dict(frames_to_tracks)
+
+     def _compute_basic_metrics(self, gt_tracks: Dict[int, List[Dict]], pred_tracks: Dict[int, List[Dict]]):
+         """
+         Compute per-frame true positives / false positives / false negatives and average IoU.
+         Matching is performed with the Hungarian algorithm (scipy). Matches with IoU < threshold are discarded.
+         """
+         tp = fp = fn = 0
+         total_iou = 0.0
+         iou_count = 0
+
+         frames = sorted(set(list(gt_tracks.keys()) + list(pred_tracks.keys())))
+         for f in frames:
+             gts = gt_tracks.get(f, [])
+             preds = pred_tracks.get(f, [])
+
+             if not gts and not preds:
+                 continue
+             if not gts:
+                 fp += len(preds)
+                 continue
+             if not preds:
+                 fn += len(gts)
+                 continue
+
+             gt_boxes = np.array([g["bbox"] for g in gts], dtype=float)
+             pred_boxes = np.array([p["bbox"] for p in preds], dtype=float)
+
+             # motmetrics expects boxes as (x, y, w, h); the extracted bboxes are [x1, y1, x2, y2]
+             gt_xywh = gt_boxes.copy()
+             gt_xywh[:, 2:] -= gt_xywh[:, :2]
+             pred_xywh = pred_boxes.copy()
+             pred_xywh[:, 2:] -= pred_xywh[:, :2]
+
+             # get cost matrix from motmetrics (cost = 1 - IoU)
+             cost_mat = mm.distances.iou_matrix(gt_xywh, pred_xywh, max_iou=1.0)
+             # replace NaNs (if any) with a large cost so Hungarian will avoid them
+             cost_for_assignment = np.where(np.isnan(cost_mat), 1e6, cost_mat)
+
+             # Hungarian assignment (minimize cost -> maximize IoU)
+             row_idx, col_idx = linear_sum_assignment(cost_for_assignment)
+
+             matched_gt = set()
+             matched_pred = set()
+             for r, c in zip(row_idx, col_idx):
+                 if r < cost_mat.shape[0] and c < cost_mat.shape[1]:
+                     # IoU = 1 - cost
+                     cost_val = cost_mat[r, c]
+                     if np.isnan(cost_val):
+                         continue
+                     iou_val = 1.0 - float(cost_val)
+                     if iou_val >= self.iou_threshold:
+                         matched_gt.add(r)
+                         matched_pred.add(c)
+                         total_iou += iou_val
+                         iou_count += 1
+
+             frame_tp = len(matched_gt)
+             frame_fp = len(preds) - len(matched_pred)
+             frame_fn = len(gts) - len(matched_gt)
+
+             tp += frame_tp
+             fp += frame_fp
+             fn += frame_fn
+
+         precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+         recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+         f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
+         avg_iou = total_iou / iou_count if iou_count > 0 else 0.0
+
+         total_gt = sum(len(v) for v in gt_tracks.values())
+         total_pred = sum(len(v) for v in pred_tracks.values())
+
+         return {
+             "precision": precision,
+             "recall": recall,
+             "f1": f1,
+             "avg_iou": avg_iou,
+             "tp": tp,
+             "fp": fp,
+             "fn": fn,
+             "total_gt": total_gt,
+             "total_pred": total_pred,
+         }
+
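For illustration, a minimal standalone sketch of this matching step on two toy boxes (values are only illustrative and not taken from the package): the boxes are given as (x, y, w, h), the format `mm.distances.iou_matrix` expects, and the Hungarian assignment pairs ground truth with predictions by lowest 1 - IoU cost.

    import numpy as np
    import motmetrics as mm
    from scipy.optimize import linear_sum_assignment

    gt = np.array([[10, 10, 50, 80], [200, 40, 60, 60]], dtype=float)     # (x, y, w, h)
    pred = np.array([[12, 14, 50, 78], [400, 400, 30, 30]], dtype=float)  # (x, y, w, h)

    cost = mm.distances.iou_matrix(gt, pred, max_iou=1.0)                 # cost = 1 - IoU
    rows, cols = linear_sum_assignment(np.where(np.isnan(cost), 1e6, cost))
    for r, c in zip(rows, cols):
        iou = 1.0 - cost[r, c]
        print(r, c, round(iou, 3))  # only pairs with iou >= iou_threshold count as TP
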
+     def _compute_mot_metrics(self, gt_tracks: Dict[int, List[Dict]], pred_tracks: Dict[int, List[Dict]]):
+         """
+         Use motmetrics.MOTAccumulator to collect associations per frame and compute common MOT metrics.
+         The distance matrix comes from motmetrics.distances.iou_matrix (which returns 1 - IoU).
+         Pairs with distance > (1 - iou_threshold) are set to infinity to exclude them from matching.
+         """
+         acc = mm.MOTAccumulator(auto_id=True)
+
+         frames = sorted(set(list(gt_tracks.keys()) + list(pred_tracks.keys())))
+         for f in frames:
+             gts = gt_tracks.get(f, [])
+             preds = pred_tracks.get(f, [])
+
+             gt_ids = [g["track_id"] for g in gts]
+             pred_ids = [p["track_id"] for p in preds]
+
+             if gts and preds:
+                 gt_boxes = np.array([g["bbox"] for g in gts], dtype=float)
+                 pred_boxes = np.array([p["bbox"] for p in preds], dtype=float)
+
+                 # convert [x1, y1, x2, y2] to the (x, y, w, h) format expected by motmetrics
+                 gt_xywh = gt_boxes.copy()
+                 gt_xywh[:, 2:] -= gt_xywh[:, :2]
+                 pred_xywh = pred_boxes.copy()
+                 pred_xywh[:, 2:] -= pred_xywh[:, :2]
+
+                 # motmetrics provides a distance matrix (1 - IoU)
+                 dist_mat = mm.distances.iou_matrix(gt_xywh, pred_xywh, max_iou=1.0)
+                 # exclude pairs with IoU < threshold => distance > 1 - threshold
+                 dist_mat = np.array(dist_mat, dtype=float)
+                 dist_mat[np.isnan(dist_mat)] = np.inf
+                 dist_mat[dist_mat > (1.0 - self.iou_threshold)] = np.inf
+             else:
+                 dist_mat = np.full((len(gts), len(preds)), np.inf)
+
+             acc.update(gt_ids, pred_ids, dist_mat)
+
+         mh = mm.metrics.create()
+         summary = mh.compute(
+             acc,
+             metrics=[
+                 "mota",
+                 "motp",
+                 "idf1",
+                 "num_switches",
+                 "num_fragmentations",
+                 "num_misses",
+                 "num_false_positives",
+             ],
+             name="eval",
+         )
+
+         def get_val(col: str, default=0.0):
+             if summary.empty or col not in summary.columns:
+                 return float(default)
+             v = summary.iloc[0][col]
+             return float(v) if not np.isnan(v) else float(default)
+
+         return {
+             "mota": get_val("mota", 0.0),
+             "motp": get_val("motp", 0.0),
+             "idf1": get_val("idf1", 0.0),
+             "id_switches": int(get_val("num_switches", 0.0)),
+             "fragmentations": int(get_val("num_fragmentations", 0.0)),
+             "num_misses": int(get_val("num_misses", 0.0)),
+             "num_false_positives": int(get_val("num_false_positives", 0.0)),
+         }
+
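For reference, the headline numbers returned here follow the standard motmetrics definitions: MOTA = 1 - (num_misses + num_false_positives + num_switches) / total ground-truth objects; MOTP is the mean matching distance over all matched pairs, which with the IoU distance used above equals 1 - (average IoU of matches), so lower is better; IDF1 = 2 * IDTP / (2 * IDTP + IDFP + IDFN), i.e. the F1 score of identity-preserving matches.
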
+
+ def evaluate(
+     gt_annotation: VideoAnnotation,
+     pred_annotation: VideoAnnotation,
+     iou_threshold: float = 0.5,
+ ) -> Dict[str, Union[float, int]]:
+     """
+     Evaluate tracking predictions against ground truth.
+
+     Args:
+         gt_annotation: Ground-truth supervisely VideoAnnotation containing the reference object tracks.
+         pred_annotation: Predicted supervisely VideoAnnotation to be compared against the ground truth.
+         iou_threshold: Minimum Intersection-over-Union required for a detection to be considered a valid match.
+
+     Returns:
+         dict: JSON-serializable dict with evaluation metrics.
+     """
+     evaluator = TrackingEvaluator(iou_threshold=iou_threshold)
+     return evaluator.evaluate(gt_annotation, pred_annotation)
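A minimal usage sketch, not part of the package: `gt_ann` and `pred_ann` stand in for two VideoAnnotation objects obtained elsewhere (for example downloaded from a Supervisely project or built with the converters in the next file).

    metrics = evaluate(gt_ann, pred_ann, iou_threshold=0.5)
    print(metrics["precision"], metrics["recall"], metrics["mota"], metrics["idf1"])
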
@@ -0,0 +1,274 @@
+
+ from typing import List, Union, Dict, Tuple
+ from pathlib import Path
+ from collections import defaultdict
+ import numpy as np
+
+ import supervisely as sly
+ from supervisely.nn.model.prediction import Prediction
+ from supervisely import VideoAnnotation
+ from supervisely import logger
+
+
+ def predictions_to_video_annotation(
+     predictions: List[Prediction],
+ ) -> VideoAnnotation:
+     """
+     Convert a list of Prediction objects to a VideoAnnotation.
+
+     Args:
+         predictions: List of Prediction objects, one per frame.
+
+     Returns:
+         VideoAnnotation object with tracked objects.
+     """
+
+     if not predictions:
+         raise ValueError("Empty predictions list provided")
+
+     frame_shape = predictions[0].annotation.img_size
+     img_h, img_w = frame_shape
+     video_objects = {}
+     frames = []
+
+     for pred in predictions:
+         frame_figures = []
+         frame_idx = pred.frame_index
+
+         # Get data using public properties
+         boxes = pred.boxes  # Public property - np.array (N, 4) in tlbr format
+         classes = pred.classes  # Public property - list of class names
+         track_ids = pred.track_ids  # Public property - can be None
+
+         # Skip frame if no detections
+         if len(boxes) == 0:
+             frames.append(sly.Frame(frame_idx, []))
+             continue
+
+         # Track ids are required to build VideoObjects; fail early if they are missing
+         if track_ids is None:
+             raise ValueError(f"Prediction for frame {frame_idx} has no track ids")
+
+         for bbox, class_name, track_id in zip(boxes, classes, track_ids):
+             # Clip bbox to image boundaries
+             # Note: pred.boxes returns tlbr format (top, left, bottom, right)
+             top, left, bottom, right = bbox
+             dims = np.array([img_h, img_w, img_h, img_w]) - 1
+             top, left, bottom, right = np.clip([top, left, bottom, right], 0, dims)
+
+             # Convert to integer coordinates
+             top, left, bottom, right = int(top), int(left), int(bottom), int(right)
+
+             # Get or create VideoObject
+             if track_id not in video_objects:
+                 # Find obj_class from prediction annotation
+                 obj_class = None
+                 for label in pred.annotation.labels:
+                     if label.obj_class.name == class_name:
+                         obj_class = label.obj_class
+                         break
+
+                 if obj_class is None:
+                     # Create obj_class if not found (fallback)
+                     obj_class = sly.ObjClass(class_name, sly.Rectangle)
+
+                 video_objects[track_id] = sly.VideoObject(obj_class)
+
+             video_object = video_objects[track_id]
+             rect = sly.Rectangle(top=top, left=left, bottom=bottom, right=right)
+             frame_figures.append(sly.VideoFigure(video_object, rect, frame_idx, track_id=str(track_id)))
+
+         frames.append(sly.Frame(frame_idx, frame_figures))
+
+     objects = list(video_objects.values())
+
+     return VideoAnnotation(
+         img_size=frame_shape,
+         frames_count=len(predictions),
+         objects=sly.VideoObjectCollection(objects),
+         frames=sly.FrameCollection(frames),
+     )
+
+ def video_annotation_to_mot(
+     annotation: VideoAnnotation,
+     output_path: Union[str, Path] = None,
+     class_to_id_mapping: Dict[str, int] = None,
+ ) -> Union[str, List[str]]:
+     """
+     Convert Supervisely VideoAnnotation to MOT format.
+     MOT format: frame_id,track_id,left,top,width,height,confidence,class_id,visibility
+     """
+     mot_lines = []
+
+     # Create default class mapping if not provided
+     if class_to_id_mapping is None:
+         unique_classes = set()
+         for frame in annotation.frames:
+             for figure in frame.figures:
+                 unique_classes.add(figure.video_object.obj_class.name)
+         class_to_id_mapping = {cls_name: idx + 1 for idx, cls_name in enumerate(sorted(unique_classes))}
+
+     # Extract tracks
+     for frame in annotation.frames:
+         frame_id = frame.index + 1  # MOT uses 1-based frame indexing
+
+         for figure in frame.figures:
+             # Get track ID from VideoFigure.track_id (official API)
+             if figure.track_id is not None:
+                 track_id = int(figure.track_id)
+             else:
+                 track_id = figure.video_object.key().int
+
+             # Get bounding box
+             if isinstance(figure.geometry, sly.Rectangle):
+                 bbox = figure.geometry
+             else:
+                 bbox = figure.geometry.to_bbox()
+
+             left = bbox.left
+             top = bbox.top
+             width = bbox.width
+             height = bbox.height
+
+             # Get class ID
+             class_name = figure.video_object.obj_class.name
+             class_id = class_to_id_mapping.get(class_name, 1)
+
+             # Get confidence (default)
+             confidence = 1.0
+
+             # Visibility (assume visible)
+             visibility = 1
+
+             # Create MOT line
+             mot_line = f"{frame_id},{track_id},{left:.2f},{top:.2f},{width:.2f},{height:.2f},{confidence:.3f},{class_id},{visibility}"
+             mot_lines.append(mot_line)
+
+     # Save to file if path provided
+     if output_path:
+         output_path = Path(output_path)
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+
+         with open(output_path, 'w') as f:
+             for line in mot_lines:
+                 f.write(line + '\n')
+
+         logger.info(f"Saved MOT format to: {output_path} ({len(mot_lines)} detections)")
+         return str(output_path)
+
+     return mot_lines
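As a concrete illustration (the variable name and path are hypothetical), the calls below produce lines such as `1,3,25.00,40.00,120.00,260.00,1.000,1,1`, i.e. frame 1, track 3, a 120 x 260 box with its top-left corner at (25, 40), confidence 1.0, class id 1, fully visible.

    mot_lines = video_annotation_to_mot(video_ann)               # returns the lines as List[str]
    gt_path = video_annotation_to_mot(video_ann, "mot/gt.txt")   # writes the file and returns its path
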
+
+ def mot_to_video_annotation(
+     mot_file_path: Union[str, Path],
+     img_size: Tuple[int, int] = (1080, 1920),
+     class_mapping: Dict[int, str] = None,
+     default_class_name: str = "person",
+ ) -> VideoAnnotation:
+     """
+     Convert MOT format tracking data to Supervisely VideoAnnotation.
+     MOT format: frame_id,track_id,left,top,width,height,confidence,class_id,visibility
+     """
+     mot_file_path = Path(mot_file_path)
+
+     if not mot_file_path.exists():
+         raise FileNotFoundError(f"MOT file not found: {mot_file_path}")
+
+     logger.info(f"Loading MOT data from: {mot_file_path}")
+     logger.info(f"Image size: {img_size} (height, width)")
+
+     # Default class mapping
+     if class_mapping is None:
+         class_mapping = {1: default_class_name}
+
+     # Parse MOT file
+     video_objects = {}  # track_id -> VideoObject
+     frames_data = defaultdict(list)  # frame_idx -> list of figures
+     max_frame_idx = 0
+     img_h, img_w = img_size
+
+     with open(mot_file_path, 'r') as f:
+         for line_num, line in enumerate(f, 1):
+             line = line.strip()
+             if not line or line.startswith('#'):
+                 continue
+
+             try:
+                 parts = line.split(',')
+                 if len(parts) < 6:  # Minimum required fields
+                     continue
+
+                 frame_id = int(parts[0])
+                 track_id = int(parts[1])
+                 left = float(parts[2])
+                 top = float(parts[3])
+                 width = float(parts[4])
+                 height = float(parts[5])
+
+                 # Optional fields
+                 confidence = float(parts[6]) if len(parts) > 6 and parts[6] != '-1' else 1.0
+                 class_id = int(parts[7]) if len(parts) > 7 and parts[7] != '-1' else 1
+                 visibility = float(parts[8]) if len(parts) > 8 and parts[8] != '-1' else 1.0
+
+                 frame_idx = frame_id - 1  # Convert to 0-based indexing
+                 max_frame_idx = max(max_frame_idx, frame_idx)
+
+                 # Skip low confidence detections
+                 if confidence < 0.1:
+                     continue
+
+                 # Calculate corner coordinates
+                 right = left + width
+                 bottom = top + height
+
+                 # Clip to image boundaries
+                 left = max(0, int(left))
+                 top = max(0, int(top))
+                 right = min(int(right), img_w - 1)
+                 bottom = min(int(bottom), img_h - 1)
+
+                 # Skip invalid boxes
+                 if right <= left or bottom <= top:
+                     continue
+
+                 # Get class name
+                 class_name = class_mapping.get(class_id, default_class_name)
+
+                 # Create VideoObject if not exists
+                 if track_id not in video_objects:
+                     obj_class = sly.ObjClass(class_name, sly.Rectangle)
+                     video_objects[track_id] = sly.VideoObject(obj_class)
+
+                 video_object = video_objects[track_id]
+
+                 # Create rectangle and figure with track_id
+                 rect = sly.Rectangle(top=top, left=left, bottom=bottom, right=right)
+                 figure = sly.VideoFigure(video_object, rect, frame_idx, track_id=str(track_id))
+
+                 frames_data[frame_idx].append(figure)
+
+             except (ValueError, IndexError) as e:
+                 logger.warning(f"Skipped invalid MOT line {line_num}: {line} - {e}")
+                 continue
+
+     # Create frames
+     frames = []
+     if frames_data:
+         frames_count = max(frames_data.keys()) + 1
+
+         for frame_idx in range(frames_count):
+             figures = frames_data.get(frame_idx, [])
+             frames.append(sly.Frame(frame_idx, figures))
+     else:
+         frames_count = 1
+         frames = [sly.Frame(0, [])]
+
+     # Create VideoAnnotation
+     objects = list(video_objects.values())
+
+     annotation = VideoAnnotation(
+         img_size=img_size,
+         frames_count=frames_count,
+         objects=sly.VideoObjectCollection(objects),
+         frames=sly.FrameCollection(frames),
+     )
+
+     logger.info(f"Created VideoAnnotation with {len(objects)} tracks and {frames_count} frames")
+
+     return annotation
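Taken together with the evaluator from the first file, one possible end-to-end sketch (the file name, the `predictions` list and the `evaluate` import are assumptions, since the diff does not show module paths):

    pred_ann = predictions_to_video_annotation(predictions)      # predictions: List[Prediction] from a tracker
    gt_ann = mot_to_video_annotation("gt/gt.txt", img_size=pred_ann.img_size, class_mapping={1: "person"})
    metrics = evaluate(gt_ann, pred_ann, iou_threshold=0.5)      # evaluate() from the first file in this diff
    print({k: metrics[k] for k in ("mota", "motp", "idf1", "id_switches")})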