dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
3
7
|
import numpy as np
|
|
4
8
|
|
|
5
9
|
from ..utils import LOGGER
|
|
@@ -10,8 +14,7 @@ from .utils.kalman_filter import KalmanFilterXYAH
|
|
|
10
14
|
|
|
11
15
|
|
|
12
16
|
class STrack(BaseTrack):
|
|
13
|
-
"""
|
|
14
|
-
Single object tracking representation that uses Kalman filtering for state estimation.
|
|
17
|
+
"""Single object tracking representation that uses Kalman filtering for state estimation.
|
|
15
18
|
|
|
16
19
|
This class is responsible for storing all the information regarding individual tracklets and performs state updates
|
|
17
20
|
and predictions based on Kalman filter.
|
|
@@ -29,16 +32,17 @@ class STrack(BaseTrack):
|
|
|
29
32
|
idx (int): Index or identifier for the object.
|
|
30
33
|
frame_id (int): Current frame ID.
|
|
31
34
|
start_frame (int): Frame where the object was first detected.
|
|
35
|
+
angle (float | None): Optional angle information for oriented bounding boxes.
|
|
32
36
|
|
|
33
37
|
Methods:
|
|
34
|
-
predict
|
|
35
|
-
multi_predict
|
|
36
|
-
multi_gmc
|
|
37
|
-
activate
|
|
38
|
-
re_activate
|
|
39
|
-
update
|
|
40
|
-
convert_coords
|
|
41
|
-
tlwh_to_xyah
|
|
38
|
+
predict: Predict the next state of the object using Kalman filter.
|
|
39
|
+
multi_predict: Predict the next states for multiple tracks.
|
|
40
|
+
multi_gmc: Update multiple track states using a homography matrix.
|
|
41
|
+
activate: Activate a new tracklet.
|
|
42
|
+
re_activate: Reactivate a previously lost tracklet.
|
|
43
|
+
update: Update the state of a matched track.
|
|
44
|
+
convert_coords: Convert bounding box to x-y-aspect-height format.
|
|
45
|
+
tlwh_to_xyah: Convert tlwh bounding box to xyah format.
|
|
42
46
|
|
|
43
47
|
Examples:
|
|
44
48
|
Initialize and activate a new track
|
|
@@ -48,13 +52,12 @@ class STrack(BaseTrack):
|
|
|
48
52
|
|
|
49
53
|
shared_kalman = KalmanFilterXYAH()
|
|
50
54
|
|
|
51
|
-
def __init__(self, xywh, score, cls):
|
|
52
|
-
"""
|
|
53
|
-
Initialize a new STrack instance.
|
|
55
|
+
def __init__(self, xywh: list[float], score: float, cls: Any):
|
|
56
|
+
"""Initialize a new STrack instance.
|
|
54
57
|
|
|
55
58
|
Args:
|
|
56
|
-
xywh (
|
|
57
|
-
|
|
59
|
+
xywh (list[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where (x,
|
|
60
|
+
y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
|
|
58
61
|
score (float): Confidence score of the detection.
|
|
59
62
|
cls (Any): Class label for the detected object.
|
|
60
63
|
|
|
@@ -79,14 +82,14 @@ class STrack(BaseTrack):
|
|
|
79
82
|
self.angle = xywh[4] if len(xywh) == 6 else None
|
|
80
83
|
|
|
81
84
|
def predict(self):
|
|
82
|
-
"""
|
|
85
|
+
"""Predict the next state (mean and covariance) of the object using the Kalman filter."""
|
|
83
86
|
mean_state = self.mean.copy()
|
|
84
87
|
if self.state != TrackState.Tracked:
|
|
85
88
|
mean_state[7] = 0
|
|
86
89
|
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
|
|
87
90
|
|
|
88
91
|
@staticmethod
|
|
89
|
-
def multi_predict(stracks):
|
|
92
|
+
def multi_predict(stracks: list[STrack]):
|
|
90
93
|
"""Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
|
|
91
94
|
if len(stracks) <= 0:
|
|
92
95
|
return
|
|
@@ -101,9 +104,9 @@ class STrack(BaseTrack):
|
|
|
101
104
|
stracks[i].covariance = cov
|
|
102
105
|
|
|
103
106
|
@staticmethod
|
|
104
|
-
def multi_gmc(stracks, H=np.eye(2, 3)):
|
|
107
|
+
def multi_gmc(stracks: list[STrack], H: np.ndarray = np.eye(2, 3)):
|
|
105
108
|
"""Update state tracks positions and covariances using a homography matrix for multiple tracks."""
|
|
106
|
-
if
|
|
109
|
+
if stracks:
|
|
107
110
|
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
|
108
111
|
multi_covariance = np.asarray([st.covariance for st in stracks])
|
|
109
112
|
|
|
@@ -119,7 +122,7 @@ class STrack(BaseTrack):
|
|
|
119
122
|
stracks[i].mean = mean
|
|
120
123
|
stracks[i].covariance = cov
|
|
121
124
|
|
|
122
|
-
def activate(self, kalman_filter, frame_id):
|
|
125
|
+
def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):
|
|
123
126
|
"""Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
|
|
124
127
|
self.kalman_filter = kalman_filter
|
|
125
128
|
self.track_id = self.next_id()
|
|
@@ -132,8 +135,8 @@ class STrack(BaseTrack):
|
|
|
132
135
|
self.frame_id = frame_id
|
|
133
136
|
self.start_frame = frame_id
|
|
134
137
|
|
|
135
|
-
def re_activate(self, new_track, frame_id, new_id=False):
|
|
136
|
-
"""
|
|
138
|
+
def re_activate(self, new_track: STrack, frame_id: int, new_id: bool = False):
|
|
139
|
+
"""Reactivate a previously lost track using new detection data and update its state and attributes."""
|
|
137
140
|
self.mean, self.covariance = self.kalman_filter.update(
|
|
138
141
|
self.mean, self.covariance, self.convert_coords(new_track.tlwh)
|
|
139
142
|
)
|
|
@@ -148,9 +151,8 @@ class STrack(BaseTrack):
|
|
|
148
151
|
self.angle = new_track.angle
|
|
149
152
|
self.idx = new_track.idx
|
|
150
153
|
|
|
151
|
-
def update(self, new_track, frame_id):
|
|
152
|
-
"""
|
|
153
|
-
Update the state of a matched track.
|
|
154
|
+
def update(self, new_track: STrack, frame_id: int):
|
|
155
|
+
"""Update the state of a matched track.
|
|
154
156
|
|
|
155
157
|
Args:
|
|
156
158
|
new_track (STrack): The new track containing updated information.
|
|
@@ -177,13 +179,13 @@ class STrack(BaseTrack):
|
|
|
177
179
|
self.angle = new_track.angle
|
|
178
180
|
self.idx = new_track.idx
|
|
179
181
|
|
|
180
|
-
def convert_coords(self, tlwh):
|
|
182
|
+
def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
|
|
181
183
|
"""Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
|
|
182
184
|
return self.tlwh_to_xyah(tlwh)
|
|
183
185
|
|
|
184
186
|
@property
|
|
185
|
-
def tlwh(self):
|
|
186
|
-
"""
|
|
187
|
+
def tlwh(self) -> np.ndarray:
|
|
188
|
+
"""Get the bounding box in top-left-width-height format from the current state estimate."""
|
|
187
189
|
if self.mean is None:
|
|
188
190
|
return self._tlwh.copy()
|
|
189
191
|
ret = self.mean[:4].copy()
|
|
@@ -192,14 +194,14 @@ class STrack(BaseTrack):
|
|
|
192
194
|
return ret
|
|
193
195
|
|
|
194
196
|
@property
|
|
195
|
-
def xyxy(self):
|
|
196
|
-
"""
|
|
197
|
+
def xyxy(self) -> np.ndarray:
|
|
198
|
+
"""Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
|
|
197
199
|
ret = self.tlwh.copy()
|
|
198
200
|
ret[2:] += ret[:2]
|
|
199
201
|
return ret
|
|
200
202
|
|
|
201
203
|
@staticmethod
|
|
202
|
-
def tlwh_to_xyah(tlwh):
|
|
204
|
+
def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:
|
|
203
205
|
"""Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
|
|
204
206
|
ret = np.asarray(tlwh).copy()
|
|
205
207
|
ret[:2] += ret[2:] / 2
|
|
@@ -207,58 +209,58 @@ class STrack(BaseTrack):
|
|
|
207
209
|
return ret
|
|
208
210
|
|
|
209
211
|
@property
|
|
210
|
-
def xywh(self):
|
|
211
|
-
"""
|
|
212
|
+
def xywh(self) -> np.ndarray:
|
|
213
|
+
"""Get the current position of the bounding box in (center x, center y, width, height) format."""
|
|
212
214
|
ret = np.asarray(self.tlwh).copy()
|
|
213
215
|
ret[:2] += ret[2:] / 2
|
|
214
216
|
return ret
|
|
215
217
|
|
|
216
218
|
@property
|
|
217
|
-
def xywha(self):
|
|
218
|
-
"""
|
|
219
|
+
def xywha(self) -> np.ndarray:
|
|
220
|
+
"""Get position in (center x, center y, width, height, angle) format, warning if angle is missing."""
|
|
219
221
|
if self.angle is None:
|
|
220
222
|
LOGGER.warning("`angle` attr not found, returning `xywh` instead.")
|
|
221
223
|
return self.xywh
|
|
222
224
|
return np.concatenate([self.xywh, self.angle[None]])
|
|
223
225
|
|
|
224
226
|
@property
|
|
225
|
-
def result(self):
|
|
226
|
-
"""
|
|
227
|
+
def result(self) -> list[float]:
|
|
228
|
+
"""Get the current tracking results in the appropriate bounding box format."""
|
|
227
229
|
coords = self.xyxy if self.angle is None else self.xywha
|
|
228
|
-
return coords.tolist()
|
|
230
|
+
return [*coords.tolist(), self.track_id, self.score, self.cls, self.idx]
|
|
229
231
|
|
|
230
|
-
def __repr__(self):
|
|
231
|
-
"""
|
|
232
|
+
def __repr__(self) -> str:
|
|
233
|
+
"""Return a string representation of the STrack object including start frame, end frame, and track ID."""
|
|
232
234
|
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
|
|
233
235
|
|
|
234
236
|
|
|
235
237
|
class BYTETracker:
|
|
236
|
-
"""
|
|
237
|
-
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
|
|
238
|
+
"""BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
|
|
238
239
|
|
|
239
|
-
This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects
|
|
240
|
-
video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman
|
|
241
|
-
predicting the new object locations, and performs data association.
|
|
240
|
+
This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects
|
|
241
|
+
in a video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman
|
|
242
|
+
filtering for predicting the new object locations, and performs data association.
|
|
242
243
|
|
|
243
244
|
Attributes:
|
|
244
|
-
tracked_stracks (
|
|
245
|
-
lost_stracks (
|
|
246
|
-
removed_stracks (
|
|
245
|
+
tracked_stracks (list[STrack]): List of successfully activated tracks.
|
|
246
|
+
lost_stracks (list[STrack]): List of lost tracks.
|
|
247
|
+
removed_stracks (list[STrack]): List of removed tracks.
|
|
247
248
|
frame_id (int): The current frame ID.
|
|
248
249
|
args (Namespace): Command-line arguments.
|
|
249
250
|
max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
|
|
250
251
|
kalman_filter (KalmanFilterXYAH): Kalman Filter object.
|
|
251
252
|
|
|
252
253
|
Methods:
|
|
253
|
-
update
|
|
254
|
-
get_kalmanfilter
|
|
255
|
-
init_track
|
|
256
|
-
get_dists
|
|
257
|
-
multi_predict
|
|
258
|
-
reset_id
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
254
|
+
update: Update object tracker with new detections.
|
|
255
|
+
get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.
|
|
256
|
+
init_track: Initialize object tracking with detections.
|
|
257
|
+
get_dists: Calculate the distance between tracks and detections.
|
|
258
|
+
multi_predict: Predict the location of tracks.
|
|
259
|
+
reset_id: Reset the ID counter of STrack.
|
|
260
|
+
reset: Reset the tracker by clearing all tracks.
|
|
261
|
+
joint_stracks: Combine two lists of stracks.
|
|
262
|
+
sub_stracks: Filter out the stracks present in the second list from the first list.
|
|
263
|
+
remove_duplicate_stracks: Remove duplicate stracks based on IoU.
|
|
262
264
|
|
|
263
265
|
Examples:
|
|
264
266
|
Initialize BYTETracker and update with detection results
|
|
@@ -267,9 +269,8 @@ class BYTETracker:
|
|
|
267
269
|
>>> tracked_objects = tracker.update(results)
|
|
268
270
|
"""
|
|
269
271
|
|
|
270
|
-
def __init__(self, args, frame_rate=30):
|
|
271
|
-
"""
|
|
272
|
-
Initialize a BYTETracker instance for object tracking.
|
|
272
|
+
def __init__(self, args, frame_rate: int = 30):
|
|
273
|
+
"""Initialize a BYTETracker instance for object tracking.
|
|
273
274
|
|
|
274
275
|
Args:
|
|
275
276
|
args (Namespace): Command-line arguments containing tracking parameters.
|
|
@@ -290,8 +291,8 @@ class BYTETracker:
|
|
|
290
291
|
self.kalman_filter = self.get_kalmanfilter()
|
|
291
292
|
self.reset_id()
|
|
292
293
|
|
|
293
|
-
def update(self, results, img=None, feats=None):
|
|
294
|
-
"""
|
|
294
|
+
def update(self, results, img: np.ndarray | None = None, feats: np.ndarray | None = None) -> np.ndarray:
|
|
295
|
+
"""Update the tracker with new detections and return the current list of tracked objects."""
|
|
295
296
|
self.frame_id += 1
|
|
296
297
|
activated_stracks = []
|
|
297
298
|
refind_stracks = []
|
|
@@ -299,24 +300,19 @@ class BYTETracker:
|
|
|
299
300
|
removed_stracks = []
|
|
300
301
|
|
|
301
302
|
scores = results.conf
|
|
302
|
-
bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
|
|
303
|
-
# Add index
|
|
304
|
-
bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
|
|
305
|
-
cls = results.cls
|
|
306
|
-
|
|
307
303
|
remain_inds = scores >= self.args.track_high_thresh
|
|
308
304
|
inds_low = scores > self.args.track_low_thresh
|
|
309
305
|
inds_high = scores < self.args.track_high_thresh
|
|
310
306
|
|
|
311
307
|
inds_second = inds_low & inds_high
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
detections = self.init_track(
|
|
308
|
+
results_second = results[inds_second]
|
|
309
|
+
results = results[remain_inds]
|
|
310
|
+
feats_keep = feats_second = img
|
|
311
|
+
if feats is not None and len(feats):
|
|
312
|
+
feats_keep = feats[remain_inds]
|
|
313
|
+
feats_second = feats[inds_second]
|
|
314
|
+
|
|
315
|
+
detections = self.init_track(results, feats_keep)
|
|
320
316
|
# Add newly detected tracklets to tracked_stracks
|
|
321
317
|
unconfirmed = []
|
|
322
318
|
tracked_stracks = [] # type: list[STrack]
|
|
@@ -332,7 +328,7 @@ class BYTETracker:
|
|
|
332
328
|
if hasattr(self, "gmc") and img is not None:
|
|
333
329
|
# use try-except here to bypass errors from gmc module
|
|
334
330
|
try:
|
|
335
|
-
warp = self.gmc.apply(img,
|
|
331
|
+
warp = self.gmc.apply(img, results.xyxy)
|
|
336
332
|
except Exception:
|
|
337
333
|
warp = np.eye(2, 3)
|
|
338
334
|
STrack.multi_gmc(strack_pool, warp)
|
|
@@ -351,11 +347,11 @@ class BYTETracker:
|
|
|
351
347
|
track.re_activate(det, self.frame_id, new_id=False)
|
|
352
348
|
refind_stracks.append(track)
|
|
353
349
|
# Step 3: Second association, with low score detection boxes association the untrack to the low score detections
|
|
354
|
-
detections_second = self.init_track(
|
|
350
|
+
detections_second = self.init_track(results_second, feats_second)
|
|
355
351
|
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
|
|
356
352
|
# TODO
|
|
357
353
|
dists = matching.iou_distance(r_tracked_stracks, detections_second)
|
|
358
|
-
matches, u_track,
|
|
354
|
+
matches, u_track, _u_detection_second = matching.linear_assignment(dists, thresh=0.5)
|
|
359
355
|
for itracked, idet in matches:
|
|
360
356
|
track = r_tracked_stracks[itracked]
|
|
361
357
|
det = detections_second[idet]
|
|
@@ -408,32 +404,36 @@ class BYTETracker:
|
|
|
408
404
|
|
|
409
405
|
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
|
|
410
406
|
|
|
411
|
-
def get_kalmanfilter(self):
|
|
412
|
-
"""
|
|
407
|
+
def get_kalmanfilter(self) -> KalmanFilterXYAH:
|
|
408
|
+
"""Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
|
|
413
409
|
return KalmanFilterXYAH()
|
|
414
410
|
|
|
415
|
-
def init_track(self,
|
|
416
|
-
"""
|
|
417
|
-
|
|
411
|
+
def init_track(self, results, img: np.ndarray | None = None) -> list[STrack]:
|
|
412
|
+
"""Initialize object tracking with given detections, scores, and class labels using the STrack algorithm."""
|
|
413
|
+
if len(results) == 0:
|
|
414
|
+
return []
|
|
415
|
+
bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
|
|
416
|
+
bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
|
|
417
|
+
return [STrack(xywh, s, c) for (xywh, s, c) in zip(bboxes, results.conf, results.cls)]
|
|
418
418
|
|
|
419
|
-
def get_dists(self, tracks, detections):
|
|
420
|
-
"""
|
|
419
|
+
def get_dists(self, tracks: list[STrack], detections: list[STrack]) -> np.ndarray:
|
|
420
|
+
"""Calculate the distance between tracks and detections using IoU and optionally fuse scores."""
|
|
421
421
|
dists = matching.iou_distance(tracks, detections)
|
|
422
422
|
if self.args.fuse_score:
|
|
423
423
|
dists = matching.fuse_score(dists, detections)
|
|
424
424
|
return dists
|
|
425
425
|
|
|
426
|
-
def multi_predict(self, tracks):
|
|
426
|
+
def multi_predict(self, tracks: list[STrack]):
|
|
427
427
|
"""Predict the next states for multiple tracks using Kalman filter."""
|
|
428
428
|
STrack.multi_predict(tracks)
|
|
429
429
|
|
|
430
430
|
@staticmethod
|
|
431
431
|
def reset_id():
|
|
432
|
-
"""
|
|
432
|
+
"""Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
|
|
433
433
|
STrack.reset_id()
|
|
434
434
|
|
|
435
435
|
def reset(self):
|
|
436
|
-
"""
|
|
436
|
+
"""Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
|
|
437
437
|
self.tracked_stracks = [] # type: list[STrack]
|
|
438
438
|
self.lost_stracks = [] # type: list[STrack]
|
|
439
439
|
self.removed_stracks = [] # type: list[STrack]
|
|
@@ -442,8 +442,8 @@ class BYTETracker:
|
|
|
442
442
|
self.reset_id()
|
|
443
443
|
|
|
444
444
|
@staticmethod
|
|
445
|
-
def joint_stracks(tlista, tlistb):
|
|
446
|
-
"""
|
|
445
|
+
def joint_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]:
|
|
446
|
+
"""Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
|
|
447
447
|
exists = {}
|
|
448
448
|
res = []
|
|
449
449
|
for t in tlista:
|
|
@@ -457,14 +457,14 @@ class BYTETracker:
|
|
|
457
457
|
return res
|
|
458
458
|
|
|
459
459
|
@staticmethod
|
|
460
|
-
def sub_stracks(tlista, tlistb):
|
|
461
|
-
"""
|
|
460
|
+
def sub_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]:
|
|
461
|
+
"""Filter out the stracks present in the second list from the first list."""
|
|
462
462
|
track_ids_b = {t.track_id for t in tlistb}
|
|
463
463
|
return [t for t in tlista if t.track_id not in track_ids_b]
|
|
464
464
|
|
|
465
465
|
@staticmethod
|
|
466
|
-
def remove_duplicate_stracks(stracksa, stracksb):
|
|
467
|
-
"""
|
|
466
|
+
def remove_duplicate_stracks(stracksa: list[STrack], stracksb: list[STrack]) -> tuple[list[STrack], list[STrack]]:
|
|
467
|
+
"""Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
|
|
468
468
|
pdist = matching.iou_distance(stracksa, stracksb)
|
|
469
469
|
pairs = np.where(pdist < 0.15)
|
|
470
470
|
dupa, dupb = [], []
|
ultralytics/trackers/track.py
CHANGED
|
@@ -16,19 +16,14 @@ TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
def on_predict_start(predictor: object, persist: bool = False) -> None:
|
|
19
|
-
"""
|
|
20
|
-
Initialize trackers for object tracking during prediction.
|
|
19
|
+
"""Initialize trackers for object tracking during prediction.
|
|
21
20
|
|
|
22
21
|
Args:
|
|
23
|
-
predictor (
|
|
24
|
-
persist (bool): Whether to persist the trackers if they already exist.
|
|
25
|
-
|
|
26
|
-
Raises:
|
|
27
|
-
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
|
|
28
|
-
ValueError: If the task is 'classify' as classification doesn't support tracking.
|
|
22
|
+
predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.
|
|
23
|
+
persist (bool, optional): Whether to persist the trackers if they already exist.
|
|
29
24
|
|
|
30
25
|
Examples:
|
|
31
|
-
Initialize trackers for a predictor object
|
|
26
|
+
Initialize trackers for a predictor object
|
|
32
27
|
>>> predictor = SomePredictorClass()
|
|
33
28
|
>>> on_predict_start(predictor, persist=True)
|
|
34
29
|
"""
|
|
@@ -74,12 +69,11 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
|
|
|
74
69
|
|
|
75
70
|
|
|
76
71
|
def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
|
|
77
|
-
"""
|
|
78
|
-
Postprocess detected boxes and update with object tracking.
|
|
72
|
+
"""Postprocess detected boxes and update with object tracking.
|
|
79
73
|
|
|
80
74
|
Args:
|
|
81
75
|
predictor (object): The predictor object containing the predictions.
|
|
82
|
-
persist (bool): Whether to persist the trackers if they already exist.
|
|
76
|
+
persist (bool, optional): Whether to persist the trackers if they already exist.
|
|
83
77
|
|
|
84
78
|
Examples:
|
|
85
79
|
Postprocess predictions and update with tracking
|
|
@@ -96,8 +90,6 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
|
|
|
96
90
|
predictor.vid_path[i if is_stream else 0] = vid_path
|
|
97
91
|
|
|
98
92
|
det = (result.obb if is_obb else result.boxes).cpu().numpy()
|
|
99
|
-
if len(det) == 0:
|
|
100
|
-
continue
|
|
101
93
|
tracks = tracker.update(det, result.orig_img, getattr(result, "feats", None))
|
|
102
94
|
if len(tracks) == 0:
|
|
103
95
|
continue
|
|
@@ -109,8 +101,7 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
|
|
|
109
101
|
|
|
110
102
|
|
|
111
103
|
def register_tracker(model: object, persist: bool) -> None:
|
|
112
|
-
"""
|
|
113
|
-
Register tracking callbacks to the model for object tracking during prediction.
|
|
104
|
+
"""Register tracking callbacks to the model for object tracking during prediction.
|
|
114
105
|
|
|
115
106
|
Args:
|
|
116
107
|
model (object): The model object to register tracking callbacks for.
|