eye_cv-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. eye/__init__.py +115 -0
  2. eye/__init___supervision_original.py +120 -0
  3. eye/annotators/__init__.py +0 -0
  4. eye/annotators/base.py +22 -0
  5. eye/annotators/core.py +2699 -0
  6. eye/annotators/line.py +107 -0
  7. eye/annotators/modern.py +529 -0
  8. eye/annotators/trace.py +142 -0
  9. eye/annotators/utils.py +177 -0
  10. eye/assets/__init__.py +2 -0
  11. eye/assets/downloader.py +95 -0
  12. eye/assets/list.py +83 -0
  13. eye/classification/__init__.py +0 -0
  14. eye/classification/core.py +188 -0
  15. eye/config.py +2 -0
  16. eye/core/__init__.py +0 -0
  17. eye/core/trackers/__init__.py +1 -0
  18. eye/core/trackers/botsort_tracker.py +336 -0
  19. eye/core/trackers/bytetrack_tracker.py +284 -0
  20. eye/core/trackers/sort_tracker.py +200 -0
  21. eye/core/tracking.py +146 -0
  22. eye/dataset/__init__.py +0 -0
  23. eye/dataset/core.py +919 -0
  24. eye/dataset/formats/__init__.py +0 -0
  25. eye/dataset/formats/coco.py +258 -0
  26. eye/dataset/formats/pascal_voc.py +279 -0
  27. eye/dataset/formats/yolo.py +272 -0
  28. eye/dataset/utils.py +259 -0
  29. eye/detection/__init__.py +0 -0
  30. eye/detection/auto_convert.py +155 -0
  31. eye/detection/core.py +1529 -0
  32. eye/detection/detections_enhanced.py +392 -0
  33. eye/detection/line_zone.py +859 -0
  34. eye/detection/lmm.py +184 -0
  35. eye/detection/overlap_filter.py +270 -0
  36. eye/detection/tools/__init__.py +0 -0
  37. eye/detection/tools/csv_sink.py +181 -0
  38. eye/detection/tools/inference_slicer.py +288 -0
  39. eye/detection/tools/json_sink.py +142 -0
  40. eye/detection/tools/polygon_zone.py +202 -0
  41. eye/detection/tools/smoother.py +123 -0
  42. eye/detection/tools/smoothing.py +179 -0
  43. eye/detection/tools/smoothing_config.py +202 -0
  44. eye/detection/tools/transformers.py +247 -0
  45. eye/detection/utils.py +1175 -0
  46. eye/draw/__init__.py +0 -0
  47. eye/draw/color.py +154 -0
  48. eye/draw/utils.py +374 -0
  49. eye/filters.py +112 -0
  50. eye/geometry/__init__.py +0 -0
  51. eye/geometry/core.py +128 -0
  52. eye/geometry/utils.py +47 -0
  53. eye/keypoint/__init__.py +0 -0
  54. eye/keypoint/annotators.py +442 -0
  55. eye/keypoint/core.py +687 -0
  56. eye/keypoint/skeletons.py +2647 -0
  57. eye/metrics/__init__.py +21 -0
  58. eye/metrics/core.py +72 -0
  59. eye/metrics/detection.py +843 -0
  60. eye/metrics/f1_score.py +648 -0
  61. eye/metrics/mean_average_precision.py +628 -0
  62. eye/metrics/mean_average_recall.py +697 -0
  63. eye/metrics/precision.py +653 -0
  64. eye/metrics/recall.py +652 -0
  65. eye/metrics/utils/__init__.py +0 -0
  66. eye/metrics/utils/object_size.py +158 -0
  67. eye/metrics/utils/utils.py +9 -0
  68. eye/py.typed +0 -0
  69. eye/quick.py +104 -0
  70. eye/tracker/__init__.py +0 -0
  71. eye/tracker/byte_tracker/__init__.py +0 -0
  72. eye/tracker/byte_tracker/core.py +386 -0
  73. eye/tracker/byte_tracker/kalman_filter.py +205 -0
  74. eye/tracker/byte_tracker/matching.py +69 -0
  75. eye/tracker/byte_tracker/single_object_track.py +178 -0
  76. eye/tracker/byte_tracker/utils.py +18 -0
  77. eye/utils/__init__.py +0 -0
  78. eye/utils/conversion.py +132 -0
  79. eye/utils/file.py +159 -0
  80. eye/utils/image.py +794 -0
  81. eye/utils/internal.py +200 -0
  82. eye/utils/iterables.py +84 -0
  83. eye/utils/notebook.py +114 -0
  84. eye/utils/video.py +307 -0
  85. eye/utils_eye/__init__.py +1 -0
  86. eye/utils_eye/geometry.py +71 -0
  87. eye/utils_eye/nms.py +55 -0
  88. eye/validators/__init__.py +140 -0
  89. eye/web.py +271 -0
  90. eye_cv-1.0.0.dist-info/METADATA +319 -0
  91. eye_cv-1.0.0.dist-info/RECORD +94 -0
  92. eye_cv-1.0.0.dist-info/WHEEL +5 -0
  93. eye_cv-1.0.0.dist-info/licenses/LICENSE +21 -0
  94. eye_cv-1.0.0.dist-info/top_level.txt +1 -0
eye/metrics/utils/object_size.py ADDED
@@ -0,0 +1,158 @@
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ import numpy.typing as npt
+
+ from eye.config import ORIENTED_BOX_COORDINATES
+ from eye.metrics.core import MetricTarget
+
+ if TYPE_CHECKING:
+     from eye.detection.core import Detections
+
+ SIZE_THRESHOLDS = (32**2, 96**2)
+
+
+ class ObjectSizeCategory(Enum):
+     ANY = -1
+     SMALL = 1
+     MEDIUM = 2
+     LARGE = 3
+
+
25
+ def get_object_size_category(
+     data: npt.NDArray, metric_target: MetricTarget
+ ) -> npt.NDArray[np.int_]:
+     """
+     Get the size category of an object. Distinguish based on the metric target.
+
+     Args:
+         data (np.ndarray): The object data, shaped (N, ...).
+         metric_target (MetricTarget): Determines whether boxes, masks or
+             oriented bounding boxes are used.
+
+     Returns:
+         (np.ndarray) The size category of each object, matching
+             the enum values of ObjectSizeCategory. Shaped (N,).
+     """
+     if metric_target == MetricTarget.BOXES:
+         return get_bbox_size_category(data)
+     if metric_target == MetricTarget.MASKS:
+         return get_mask_size_category(data)
+     if metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+         return get_obb_size_category(data)
+     raise ValueError("Invalid metric type")
+
+
49
+ def get_bbox_size_category(xyxy: npt.NDArray[np.float32]) -> npt.NDArray[np.int_]:
+     """
+     Get the size category of each box in a bounding boxes array.
+
+     Args:
+         xyxy (np.ndarray): The bounding boxes array shaped (N, 4).
+
+     Returns:
+         (np.ndarray) The size category of each bounding box, matching
+             the enum values of ObjectSizeCategory. Shaped (N,).
+     """
+     if len(xyxy.shape) != 2 or xyxy.shape[1] != 4:
+         raise ValueError("Bounding boxes must be shaped (N, 4)")
+
+     width = xyxy[:, 2] - xyxy[:, 0]
+     height = xyxy[:, 3] - xyxy[:, 1]
+     areas = width * height
+
+     result = np.full(areas.shape, ObjectSizeCategory.ANY.value)
+     SM, LG = SIZE_THRESHOLDS
+     result[areas < SM] = ObjectSizeCategory.SMALL.value
+     result[(areas >= SM) & (areas < LG)] = ObjectSizeCategory.MEDIUM.value
+     result[areas >= LG] = ObjectSizeCategory.LARGE.value
+     return result
+
+
75
+ def get_mask_size_category(mask: npt.NDArray[np.bool_]) -> npt.NDArray[np.int_]:
+     """
+     Get the size category of detection masks.
+
+     Args:
+         mask (np.ndarray): The mask array shaped (N, H, W).
+
+     Returns:
+         (np.ndarray) The size category of each mask, matching
+             the enum values of ObjectSizeCategory. Shaped (N,).
+     """
+     if len(mask.shape) != 3:
+         raise ValueError("Masks must be shaped (N, H, W)")
+
+     areas = np.sum(mask, axis=(1, 2))
+
+     result = np.full(areas.shape, ObjectSizeCategory.ANY.value)
+     SM, LG = SIZE_THRESHOLDS
+     result[areas < SM] = ObjectSizeCategory.SMALL.value
+     result[(areas >= SM) & (areas < LG)] = ObjectSizeCategory.MEDIUM.value
+     result[areas >= LG] = ObjectSizeCategory.LARGE.value
+     return result
+
+
99
+ def get_obb_size_category(xyxyxyxy: npt.NDArray[np.float32]) -> npt.NDArray[np.int_]:
+     """
+     Get the size category of an oriented bounding boxes array.
+
+     Args:
+         xyxyxyxy (np.ndarray): The bounding boxes array shaped (N, 4, 2).
+
+     Returns:
+         (np.ndarray) The size category of each bounding box, matching
+             the enum values of ObjectSizeCategory. Shaped (N,).
+     """
+     if len(xyxyxyxy.shape) != 3 or xyxyxyxy.shape[1] != 4 or xyxyxyxy.shape[2] != 2:
+         raise ValueError("Oriented bounding boxes must be shaped (N, 4, 2)")
+
+     # Shoelace formula
+     x = xyxyxyxy[:, :, 0]
+     y = xyxyxyxy[:, :, 1]
+     x1, x2, x3, x4 = x.T
+     y1, y2, y3, y4 = y.T
+     areas = 0.5 * np.abs(
+         (x1 * y2 + x2 * y3 + x3 * y4 + x4 * y1)
+         - (x2 * y1 + x3 * y2 + x4 * y3 + x1 * y4)
+     )
+
+     result = np.full(areas.shape, ObjectSizeCategory.ANY.value)
+     SM, LG = SIZE_THRESHOLDS
+     result[areas < SM] = ObjectSizeCategory.SMALL.value
+     result[(areas >= SM) & (areas < LG)] = ObjectSizeCategory.MEDIUM.value
+     result[areas >= LG] = ObjectSizeCategory.LARGE.value
+     return result
+
+
131
+ def get_detection_size_category(
+     detections: Detections, metric_target: MetricTarget = MetricTarget.BOXES
+ ) -> npt.NDArray[np.int_]:
+     """
+     Get the size category of a detections object.
+
+     Args:
+         detections (Detections): The detections whose objects are categorized.
+         metric_target (MetricTarget): Determines whether boxes, masks or
+             oriented bounding boxes are used.
+
+     Returns:
+         (np.ndarray) The size category of each detected object, matching
+             the enum values of ObjectSizeCategory. Shaped (N,).
+     """
+     if metric_target == MetricTarget.BOXES:
+         return get_bbox_size_category(detections.xyxy)
+     if metric_target == MetricTarget.MASKS:
+         if detections.mask is None:
+             raise ValueError("Detections mask is not available")
+         return get_mask_size_category(detections.mask)
+     if metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+         if detections.data.get(ORIENTED_BOX_COORDINATES) is None:
+             raise ValueError("Detections oriented bounding boxes are not available")
+         return get_obb_size_category(
+             np.array(detections.data[ORIENTED_BOX_COORDINATES])
+         )
+     raise ValueError("Invalid metric type")
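
The thresholds above follow the COCO convention for object sizes: areas below 32² count as small, areas from 32² up to 96² as medium, and anything larger as large. A minimal usage sketch, assuming the module is importable from the installed wheel under the path shown in the file list:

```python
import numpy as np

from eye.metrics.utils.object_size import (
    ObjectSizeCategory,
    get_bbox_size_category,
)

xyxy = np.array(
    [
        [0, 0, 10, 10],    # area 100   -> SMALL  (< 32**2)
        [0, 0, 50, 50],    # area 2500  -> MEDIUM (>= 32**2, < 96**2)
        [0, 0, 100, 100],  # area 10000 -> LARGE  (>= 96**2)
    ],
    dtype=np.float32,
)

categories = get_bbox_size_category(xyxy)
assert categories.tolist() == [
    ObjectSizeCategory.SMALL.value,   # 1
    ObjectSizeCategory.MEDIUM.value,  # 2
    ObjectSizeCategory.LARGE.value,   # 3
]
```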
eye/metrics/utils/utils.py ADDED
@@ -0,0 +1,9 @@
+ def ensure_pandas_installed():
+     try:
+         import pandas  # noqa
+     except ImportError:
+         raise ImportError(
+             "`metrics` extra is required to run the function."
+             " Run `pip install 'eye[metrics]'` or"
+             " `poetry add eye -E metrics`"
+         )
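
A sketch of how such a guard is typically used: call it at the top of any pandas-dependent code path so the install hint surfaces before work begins. The `detections_to_dataframe` helper below is hypothetical, not part of the package:

```python
from eye.metrics.utils.utils import ensure_pandas_installed


def detections_to_dataframe(records: list):
    # Hypothetical helper: fail fast with install instructions if the
    # optional `metrics` extra (which provides pandas) is missing.
    ensure_pandas_installed()
    import pandas as pd

    return pd.DataFrame(records)
```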
eye/py.typed ADDED
File without changes
eye/quick.py ADDED
@@ -0,0 +1,104 @@
+ """Quick helper functions for one-line operations."""
+
+ import numpy as np
+ from typing import Optional, Any
+ from .core.detections import Detections
+ from .core.tracking import Tracker, TrackerType
+ from .annotators.box import BoxAnnotator
+ from .annotators.label import LabelAnnotator
+
+
+ def detect(model_results: Any, model_type: str = "yolo") -> Detections:
+     """Convert any model output to Detections (one-liner).
+
+     Args:
+         model_results: Model output
+         model_type: "yolo", "tensorflow", "pytorch", "opencv"
+
+     Returns:
+         Detections
+
+     Example:
+         >>> detections = eye.detect(yolo_results)
+         >>> detections = eye.detect(tf_output, "tensorflow")
+     """
+     if model_type == "yolo":
+         return Detections.from_yolo(model_results)
+     elif model_type == "tensorflow":
+         # Assume tuple: (boxes, scores, classes)
+         return Detections.from_tensorflow(*model_results)
+     elif model_type == "pytorch":
+         return Detections.from_pytorch(model_results)
+     elif model_type == "opencv":
+         return Detections.from_opencv(model_results)
+     else:
+         raise ValueError(f"Unknown model_type: {model_type}")
+
+
+ def track(
+     detections: Detections,
+     tracker: Optional[Tracker] = None,
+     use_case: str = "traffic"
+ ) -> Detections:
+     """Track detections (one-liner with auto-setup).
+
+     Args:
+         detections: Input detections
+         tracker: Existing tracker (or None to create)
+         use_case: "traffic", "pedestrians", "sports", etc.
+
+     Returns:
+         Tracked detections
+
+     Example:
+         >>> tracked = eye.track(detections, use_case="traffic")
+     """
+     if tracker is None:
+         # Create default tracker
+         if use_case == "pedestrians":
+             tracker = Tracker(TrackerType.BYTETRACK, inflation_factor=1.8)
+         else:
+             tracker = Tracker(TrackerType.SORT, inflation_factor=1.5)
+
+     return tracker.update(detections)
+
+
+ def annotate(
+     image: np.ndarray,
+     detections: Detections,
+     labels: Optional[list] = None,
+     show_boxes: bool = True,
+     show_labels: bool = True
+ ) -> np.ndarray:
+     """Annotate image (one-liner).
+
+     Args:
+         image: Input image
+         detections: Detections to draw
+         labels: Custom labels (or None for auto)
+         show_boxes: Draw bounding boxes
+         show_labels: Draw labels
+
+     Returns:
+         Annotated image
+
+     Example:
+         >>> annotated = eye.annotate(frame, detections)
+     """
+     result = image.copy()
+
+     if show_boxes:
+         box_ann = BoxAnnotator()
+         result = box_ann.annotate(result, detections)
+
+     if show_labels:
+         if labels is None and detections.tracker_id is not None:
+             labels = [f"#{tid}" for tid in detections.tracker_id]
+         elif labels is None and detections.class_id is not None:
+             labels = [f"Class {cid}" for cid in detections.class_id]
+
+         if labels:
+             label_ann = LabelAnnotator()
+             result = label_ann.annotate(result, detections, labels)
+
+     return result
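
Taken together, the three helpers compose into a short pipeline. A sketch based only on the docstring examples above; it assumes the helpers are re-exported at the package top level and that an Ultralytics model is the detection backend (weights path and image path are placeholders):

```python
import cv2
import eye
from ultralytics import YOLO  # assumed backend, as in the docstrings

model = YOLO("yolov8n.pt")       # placeholder weights
frame = cv2.imread("frame.jpg")  # placeholder input

detections = eye.detect(model(frame)[0])             # model output -> Detections
tracked = eye.track(detections, use_case="traffic")  # default SORT tracker
annotated = eye.annotate(frame, tracked)             # boxes plus "#<id>" labels
cv2.imwrite("annotated.jpg", annotated)
```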
eye/tracker/__init__.py ADDED
File without changes
eye/tracker/byte_tracker/__init__.py ADDED
File without changes
eye/tracker/byte_tracker/core.py ADDED
@@ -0,0 +1,386 @@
+ from typing import List, Tuple
+
+ import numpy as np
+
+ from eye.detection.core import Detections
+ from eye.detection.utils import box_iou_batch
+ from eye.tracker.byte_tracker import matching
+ from eye.tracker.byte_tracker.kalman_filter import KalmanFilter
+ from eye.tracker.byte_tracker.single_object_track import STrack, TrackState
+ from eye.tracker.byte_tracker.utils import IdCounter
+
+
+ class ByteTrack:
+     """
+     Initialize the ByteTrack object.
+
+     <video controls>
+         <source src="https://media.roboflow.com/eye/video-examples/how-to/track-objects/annotate-video-with-traces.mp4" type="video/mp4">
+     </video>
+
+     Parameters:
+         track_activation_threshold (float): Detection confidence threshold
+             for track activation. Increasing track_activation_threshold improves accuracy
+             and stability but might miss true detections. Decreasing it increases
+             completeness but risks introducing noise and instability.
+         lost_track_buffer (int): Number of frames to buffer when a track is lost.
+             Increasing lost_track_buffer enhances occlusion handling, significantly
+             reducing the likelihood of track fragmentation or disappearance caused
+             by brief detection gaps.
+         minimum_matching_threshold (float): Threshold for matching tracks with detections.
+             Increasing minimum_matching_threshold improves accuracy but risks fragmentation.
+             Decreasing it improves completeness but risks false positives and drift.
+         frame_rate (int): The frame rate of the video.
+         minimum_consecutive_frames (int): Number of consecutive frames that an object must
+             be tracked before it is considered a 'valid' track.
+             Increasing minimum_consecutive_frames prevents the creation of accidental tracks from
+             false detection or double detection, but risks missing shorter tracks.
+     """  # noqa: E501 // docs
+
40
+     def __init__(
+         self,
+         track_activation_threshold: float = 0.25,
+         lost_track_buffer: int = 30,
+         minimum_matching_threshold: float = 0.8,
+         frame_rate: int = 30,
+         minimum_consecutive_frames: int = 1,
+     ):
+         self.track_activation_threshold = track_activation_threshold
+         self.minimum_matching_threshold = minimum_matching_threshold
+
+         self.frame_id = 0
+         self.det_thresh = self.track_activation_threshold + 0.1
+         self.max_time_lost = int(frame_rate / 30.0 * lost_track_buffer)
+         self.minimum_consecutive_frames = minimum_consecutive_frames
+         self.kalman_filter = KalmanFilter()
+         self.shared_kalman = KalmanFilter()
+
+         self.tracked_tracks: List[STrack] = []
+         self.lost_tracks: List[STrack] = []
+         self.removed_tracks: List[STrack] = []
+
+         # Warning, possible bug: If you also set internal_id to start at 1,
+         # all traces will be connected across objects.
+         self.internal_id_counter = IdCounter()
+         self.external_id_counter = IdCounter(start_id=1)
+
67
+     def update_with_detections(self, detections: Detections) -> Detections:
+         """
+         Updates the tracker with the provided detections and returns the updated
+         detection results.
+
+         Args:
+             detections (Detections): The detections to pass through the tracker.
+
+         Example:
+             ```python
+             import numpy as np
+             import eye as sv
+             from ultralytics import YOLO
+             model = YOLO(<MODEL_PATH>)
+             tracker = sv.ByteTrack()
+
+             box_annotator = sv.BoxAnnotator()
+             label_annotator = sv.LabelAnnotator()
+
+             def callback(frame: np.ndarray, index: int) -> np.ndarray:
+                 results = model(frame)[0]
+                 detections = sv.Detections.from_ultralytics(results)
+                 detections = tracker.update_with_detections(detections)
+
+                 labels = [f"#{tracker_id}" for tracker_id in detections.tracker_id]
+
+                 annotated_frame = box_annotator.annotate(
+                     scene=frame.copy(), detections=detections)
+                 annotated_frame = label_annotator.annotate(
+                     scene=annotated_frame, detections=detections, labels=labels)
+                 return annotated_frame
+
+             sv.process_video(
+                 source_path=<SOURCE_VIDEO_PATH>,
+                 target_path=<TARGET_VIDEO_PATH>,
+                 callback=callback
+             )
+             ```
+         """
+         tensors = np.hstack(
+             (
+                 detections.xyxy,
+                 detections.confidence[:, np.newaxis],
+             )
+         )
+         tracks = self.update_with_tensors(tensors=tensors)
+
+         if len(tracks) > 0:
+             detection_bounding_boxes = np.asarray([det[:4] for det in tensors])
+             track_bounding_boxes = np.asarray([track.tlbr for track in tracks])
+
+             ious = box_iou_batch(detection_bounding_boxes, track_bounding_boxes)
+
+             iou_costs = 1 - ious
+
+             matches, _, _ = matching.linear_assignment(iou_costs, 0.5)
+             detections.tracker_id = np.full(len(detections), -1, dtype=int)
+             for i_detection, i_track in matches:
+                 detections.tracker_id[i_detection] = int(
+                     tracks[i_track].external_track_id
+                 )
+
+             return detections[detections.tracker_id != -1]
+
+         else:
+             detections = Detections.empty()
+             detections.tracker_id = np.array([], dtype=int)
+
+             return detections
+
137
+     def reset(self) -> None:
+         """
+         Resets the internal state of the ByteTrack tracker.
+
+         This method clears the tracking data, including tracked, lost,
+         and removed tracks, as well as resetting the frame counter. It's
+         particularly useful when processing multiple videos sequentially,
+         ensuring the tracker starts with a clean state for each new video.
+         """
+         self.frame_id = 0
+         self.internal_id_counter.reset()
+         self.external_id_counter.reset()
+         self.tracked_tracks = []
+         self.lost_tracks = []
+         self.removed_tracks = []
+
153
+     def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
+         """
+         Updates the tracker with the provided tensors and returns the updated tracks.
+
+         Parameters:
+             tensors: The new tensors to update with.
+
+         Returns:
+             List[STrack]: Updated tracks.
+         """
+         self.frame_id += 1
+         activated_starcks = []
+         refind_stracks = []
+         lost_stracks = []
+         removed_stracks = []
+
+         scores = tensors[:, 4]
+         bboxes = tensors[:, :4]
+
+         remain_inds = scores > self.track_activation_threshold
+         inds_low = scores > 0.1
+         inds_high = scores < self.track_activation_threshold
+
+         inds_second = np.logical_and(inds_low, inds_high)
+         dets_second = bboxes[inds_second]
+         dets = bboxes[remain_inds]
+         scores_keep = scores[remain_inds]
+         scores_second = scores[inds_second]
+
+         if len(dets) > 0:
+             """Detections"""
+             detections = [
+                 STrack(
+                     STrack.tlbr_to_tlwh(tlbr),
+                     score_keep,
+                     self.minimum_consecutive_frames,
+                     self.shared_kalman,
+                     self.internal_id_counter,
+                     self.external_id_counter,
+                 )
+                 for (tlbr, score_keep) in zip(dets, scores_keep)
+             ]
+         else:
+             detections = []
+
+         """ Add newly detected tracklets to tracked_stracks"""
+         unconfirmed = []
+         tracked_stracks = []  # type: list[STrack]
+
+         for track in self.tracked_tracks:
+             if not track.is_activated:
+                 unconfirmed.append(track)
+             else:
+                 tracked_stracks.append(track)
+
+         """ Step 2: First association, with high score detection boxes"""
+         strack_pool = joint_tracks(tracked_stracks, self.lost_tracks)
+         # Predict the current location with KF
+         STrack.multi_predict(strack_pool, self.shared_kalman)
+         dists = matching.iou_distance(strack_pool, detections)
+
+         dists = matching.fuse_score(dists, detections)
+         matches, u_track, u_detection = matching.linear_assignment(
+             dists, thresh=self.minimum_matching_threshold
+         )
+
+         for itracked, idet in matches:
+             track = strack_pool[itracked]
+             det = detections[idet]
+             if track.state == TrackState.Tracked:
+                 track.update(detections[idet], self.frame_id)
+                 activated_starcks.append(track)
+             else:
+                 track.re_activate(det, self.frame_id)
+                 refind_stracks.append(track)
+
+         """ Step 3: Second association, with low score detection boxes"""
+         # Associate the still-unmatched tracks with the low score detections
+         if len(dets_second) > 0:
+             """Detections"""
+             detections_second = [
+                 STrack(
+                     STrack.tlbr_to_tlwh(tlbr),
+                     score_second,
+                     self.minimum_consecutive_frames,
+                     self.shared_kalman,
+                     self.internal_id_counter,
+                     self.external_id_counter,
+                 )
+                 for (tlbr, score_second) in zip(dets_second, scores_second)
+             ]
+         else:
+             detections_second = []
+         r_tracked_stracks = [
+             strack_pool[i]
+             for i in u_track
+             if strack_pool[i].state == TrackState.Tracked
+         ]
+         dists = matching.iou_distance(r_tracked_stracks, detections_second)
+         matches, u_track, u_detection_second = matching.linear_assignment(
+             dists, thresh=0.5
+         )
+         for itracked, idet in matches:
+             track = r_tracked_stracks[itracked]
+             det = detections_second[idet]
+             if track.state == TrackState.Tracked:
+                 track.update(det, self.frame_id)
+                 activated_starcks.append(track)
+             else:
+                 track.re_activate(det, self.frame_id)
+                 refind_stracks.append(track)
+
+         for it in u_track:
+             track = r_tracked_stracks[it]
+             if not track.state == TrackState.Lost:
+                 track.state = TrackState.Lost
+                 lost_stracks.append(track)
+
+         """Deal with unconfirmed tracks, usually tracks with only one beginning frame"""
+         detections = [detections[i] for i in u_detection]
+         dists = matching.iou_distance(unconfirmed, detections)
+
+         dists = matching.fuse_score(dists, detections)
+         matches, u_unconfirmed, u_detection = matching.linear_assignment(
+             dists, thresh=0.7
+         )
+         for itracked, idet in matches:
+             unconfirmed[itracked].update(detections[idet], self.frame_id)
+             activated_starcks.append(unconfirmed[itracked])
+         for it in u_unconfirmed:
+             track = unconfirmed[it]
+             track.state = TrackState.Removed
+             removed_stracks.append(track)
+
+         """ Step 4: Init new stracks"""
+         for inew in u_detection:
+             track = detections[inew]
+             if track.score < self.det_thresh:
+                 continue
+             track.activate(self.kalman_filter, self.frame_id)
+             activated_starcks.append(track)
+         """ Step 5: Update state"""
+         for track in self.lost_tracks:
+             if self.frame_id - track.frame_id > self.max_time_lost:
+                 track.state = TrackState.Removed
+                 removed_stracks.append(track)
+
+         self.tracked_tracks = [
+             t for t in self.tracked_tracks if t.state == TrackState.Tracked
+         ]
+         self.tracked_tracks = joint_tracks(self.tracked_tracks, activated_starcks)
+         self.tracked_tracks = joint_tracks(self.tracked_tracks, refind_stracks)
+         self.lost_tracks = sub_tracks(self.lost_tracks, self.tracked_tracks)
+         self.lost_tracks.extend(lost_stracks)
+         self.lost_tracks = sub_tracks(self.lost_tracks, self.removed_tracks)
+         self.removed_tracks = removed_stracks
+         self.tracked_tracks, self.lost_tracks = remove_duplicate_tracks(
+             self.tracked_tracks, self.lost_tracks
+         )
+         output_stracks = [track for track in self.tracked_tracks if track.is_activated]
+
+         return output_stracks
+
+
317
+ def joint_tracks(
+     track_list_a: List[STrack], track_list_b: List[STrack]
+ ) -> List[STrack]:
+     """
+     Joins two lists of tracks, ensuring that the resulting list does not
+     contain tracks with duplicate internal_track_id values.
+
+     Parameters:
+         track_list_a: First list of tracks (with internal_track_id attribute).
+         track_list_b: Second list of tracks (with internal_track_id attribute).
+
+     Returns:
+         Combined list of tracks from track_list_a and track_list_b
+         without duplicate internal_track_id values.
+     """
+     seen_track_ids = set()
+     result = []
+
+     for track in track_list_a + track_list_b:
+         if track.internal_track_id not in seen_track_ids:
+             seen_track_ids.add(track.internal_track_id)
+             result.append(track)
+
+     return result
+
+
343
+ def sub_tracks(track_list_a: List[STrack], track_list_b: List[STrack]) -> List[STrack]:
+     """
+     Returns a list of tracks from track_list_a after removing any tracks
+     that share the same internal_track_id with tracks in track_list_b.
+
+     Parameters:
+         track_list_a: List of tracks (with internal_track_id attribute).
+         track_list_b: List of tracks (with internal_track_id attribute) to
+             be subtracted from track_list_a.
+     Returns:
+         List of remaining tracks from track_list_a after subtraction.
+     """
+     tracks = {track.internal_track_id: track for track in track_list_a}
+     track_ids_b = {track.internal_track_id for track in track_list_b}
+
+     for track_id in track_ids_b:
+         tracks.pop(track_id, None)
+
+     return list(tracks.values())
+
+
364
+ def remove_duplicate_tracks(
+     tracks_a: List[STrack], tracks_b: List[STrack]
+ ) -> Tuple[List[STrack], List[STrack]]:
+     pairwise_distance = matching.iou_distance(tracks_a, tracks_b)
+     matching_pairs = np.where(pairwise_distance < 0.15)
+
+     duplicates_a, duplicates_b = set(), set()
+     for track_index_a, track_index_b in zip(*matching_pairs):
+         time_a = tracks_a[track_index_a].frame_id - tracks_a[track_index_a].start_frame
+         time_b = tracks_b[track_index_b].frame_id - tracks_b[track_index_b].start_frame
+         if time_a > time_b:
+             duplicates_b.add(track_index_b)
+         else:
+             duplicates_a.add(track_index_a)
+
+     result_a = [
+         track for index, track in enumerate(tracks_a) if index not in duplicates_a
+     ]
+     result_b = [
+         track for index, track in enumerate(tracks_b) if index not in duplicates_b
+     ]
+
+     return result_a, result_b
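
The reset docstring above recommends reusing one tracker across videos. A minimal sketch of that pattern, assuming the package re-exports a supervision-style frame generator (`get_video_frames_generator`) alongside ByteTrack; paths and weights are placeholders:

```python
import eye as sv  # the package's own examples alias it as `sv`
from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # placeholder weights
tracker = sv.ByteTrack(track_activation_threshold=0.25, lost_track_buffer=30)

for source in ["first.mp4", "second.mp4"]:  # placeholder inputs
    tracker.reset()  # clean state per video so track IDs do not leak across files
    for frame in sv.get_video_frames_generator(source_path=source):
        detections = sv.Detections.from_ultralytics(model(frame)[0])
        detections = tracker.update_with_detections(detections)
```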