dgenerate-ultralytics-headless 8.3.141__py3-none-any.whl → 8.3.144__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/METADATA +1 -1
- dgenerate_ultralytics_headless-8.3.144.dist-info/RECORD +272 -0
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +12 -12
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +22 -19
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -158
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +13 -11
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +18 -12
- ultralytics/solutions/object_cropper.py +12 -5
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +215 -85
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +84 -42
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- dgenerate_ultralytics_headless-8.3.141.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/trackers/byte_tracker.py
CHANGED
@@ -1,5 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from typing import Any, List, Optional, Tuple
+
 import numpy as np
 
 from ..utils import LOGGER
@@ -29,16 +31,17 @@ class STrack(BaseTrack):
         idx (int): Index or identifier for the object.
         frame_id (int): Current frame ID.
         start_frame (int): Frame where the object was first detected.
+        angle (float | None): Optional angle information for oriented bounding boxes.
 
     Methods:
-        predict
-        multi_predict
-        multi_gmc
-        activate
-        re_activate
-        update
-        convert_coords
-        tlwh_to_xyah
+        predict: Predict the next state of the object using Kalman filter.
+        multi_predict: Predict the next states for multiple tracks.
+        multi_gmc: Update multiple track states using a homography matrix.
+        activate: Activate a new tracklet.
+        re_activate: Reactivate a previously lost tracklet.
+        update: Update the state of a matched track.
+        convert_coords: Convert bounding box to x-y-aspect-height format.
+        tlwh_to_xyah: Convert tlwh bounding box to xyah format.
 
     Examples:
         Initialize and activate a new track
@@ -48,7 +51,7 @@ class STrack(BaseTrack):
 
     shared_kalman = KalmanFilterXYAH()
 
-    def __init__(self, xywh, score, cls):
+    def __init__(self, xywh: List[float], score: float, cls: Any):
         """
         Initialize a new STrack instance.
 
@@ -79,14 +82,14 @@ class STrack(BaseTrack):
         self.angle = xywh[4] if len(xywh) == 6 else None
 
     def predict(self):
-        """
+        """Predict the next state (mean and covariance) of the object using the Kalman filter."""
         mean_state = self.mean.copy()
         if self.state != TrackState.Tracked:
             mean_state[7] = 0
         self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
 
     @staticmethod
-    def multi_predict(stracks):
+    def multi_predict(stracks: List["STrack"]):
         """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
         if len(stracks) <= 0:
             return
@@ -101,7 +104,7 @@ class STrack(BaseTrack):
             stracks[i].covariance = cov
 
     @staticmethod
-    def multi_gmc(stracks, H=np.eye(2, 3)):
+    def multi_gmc(stracks: List["STrack"], H: np.ndarray = np.eye(2, 3)):
         """Update state tracks positions and covariances using a homography matrix for multiple tracks."""
         if len(stracks) > 0:
             multi_mean = np.asarray([st.mean.copy() for st in stracks])
@@ -119,7 +122,7 @@ class STrack(BaseTrack):
                 stracks[i].mean = mean
                 stracks[i].covariance = cov
 
-    def activate(self, kalman_filter, frame_id):
+    def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):
         """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
         self.kalman_filter = kalman_filter
         self.track_id = self.next_id()
@@ -132,8 +135,8 @@ class STrack(BaseTrack):
         self.frame_id = frame_id
         self.start_frame = frame_id
 
-    def re_activate(self, new_track, frame_id, new_id=False):
-        """
+    def re_activate(self, new_track: "STrack", frame_id: int, new_id: bool = False):
+        """Reactivate a previously lost track using new detection data and update its state and attributes."""
         self.mean, self.covariance = self.kalman_filter.update(
             self.mean, self.covariance, self.convert_coords(new_track.tlwh)
         )
@@ -148,7 +151,7 @@ class STrack(BaseTrack):
         self.angle = new_track.angle
         self.idx = new_track.idx
 
-    def update(self, new_track, frame_id):
+    def update(self, new_track: "STrack", frame_id: int):
         """
         Update the state of a matched track.
 
@@ -177,13 +180,13 @@ class STrack(BaseTrack):
         self.angle = new_track.angle
         self.idx = new_track.idx
 
-    def convert_coords(self, tlwh):
+    def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
         """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
         return self.tlwh_to_xyah(tlwh)
 
     @property
-    def tlwh(self):
-        """
+    def tlwh(self) -> np.ndarray:
+        """Get the bounding box in top-left-width-height format from the current state estimate."""
         if self.mean is None:
             return self._tlwh.copy()
         ret = self.mean[:4].copy()
@@ -192,14 +195,14 @@ class STrack(BaseTrack):
         return ret
 
     @property
-    def xyxy(self):
-        """
+    def xyxy(self) -> np.ndarray:
+        """Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
         ret = self.tlwh.copy()
         ret[2:] += ret[:2]
         return ret
 
     @staticmethod
-    def tlwh_to_xyah(tlwh):
+    def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:
         """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
         ret = np.asarray(tlwh).copy()
         ret[:2] += ret[2:] / 2
@@ -207,28 +210,28 @@ class STrack(BaseTrack):
         return ret
 
     @property
-    def xywh(self):
-        """
+    def xywh(self) -> np.ndarray:
+        """Get the current position of the bounding box in (center x, center y, width, height) format."""
         ret = np.asarray(self.tlwh).copy()
         ret[:2] += ret[2:] / 2
         return ret
 
     @property
-    def xywha(self):
-        """
+    def xywha(self) -> np.ndarray:
+        """Get position in (center x, center y, width, height, angle) format, warning if angle is missing."""
         if self.angle is None:
             LOGGER.warning("`angle` attr not found, returning `xywh` instead.")
             return self.xywh
         return np.concatenate([self.xywh, self.angle[None]])
 
     @property
-    def result(self):
-        """
+    def result(self) -> List[float]:
+        """Get the current tracking results in the appropriate bounding box format."""
         coords = self.xyxy if self.angle is None else self.xywha
         return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
 
-    def __repr__(self):
-        """
+    def __repr__(self) -> str:
+        """Return a string representation of the STrack object including start frame, end frame, and track ID."""
         return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
 
 
@@ -250,15 +253,16 @@ class BYTETracker:
         kalman_filter (KalmanFilterXYAH): Kalman Filter object.
 
     Methods:
-        update
-        get_kalmanfilter
-        init_track
-        get_dists
-        multi_predict
-        reset_id
-
-
-
+        update: Update object tracker with new detections.
+        get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.
+        init_track: Initialize object tracking with detections.
+        get_dists: Calculate the distance between tracks and detections.
+        multi_predict: Predict the location of tracks.
+        reset_id: Reset the ID counter of STrack.
+        reset: Reset the tracker by clearing all tracks.
+        joint_stracks: Combine two lists of stracks.
+        sub_stracks: Filter out the stracks present in the second list from the first list.
+        remove_duplicate_stracks: Remove duplicate stracks based on IoU.
 
     Examples:
         Initialize BYTETracker and update with detection results
@@ -267,7 +271,7 @@ class BYTETracker:
         >>> tracked_objects = tracker.update(results)
     """
 
-    def __init__(self, args, frame_rate=30):
+    def __init__(self, args, frame_rate: int = 30):
         """
         Initialize a BYTETracker instance for object tracking.
 
@@ -280,9 +284,9 @@ class BYTETracker:
         >>> args = Namespace(track_buffer=30)
         >>> tracker = BYTETracker(args, frame_rate=30)
         """
-        self.tracked_stracks = []  # type:
-        self.lost_stracks = []  # type:
-        self.removed_stracks = []  # type:
+        self.tracked_stracks = []  # type: List[STrack]
+        self.lost_stracks = []  # type: List[STrack]
+        self.removed_stracks = []  # type: List[STrack]
 
         self.frame_id = 0
         self.args = args
@@ -290,8 +294,8 @@ class BYTETracker:
         self.kalman_filter = self.get_kalmanfilter()
         self.reset_id()
 
-    def update(self, results, img=None, feats=None):
-        """
+    def update(self, results, img: Optional[np.ndarray] = None, feats: Optional[np.ndarray] = None) -> np.ndarray:
+        """Update the tracker with new detections and return the current list of tracked objects."""
         self.frame_id += 1
         activated_stracks = []
         refind_stracks = []
@@ -319,7 +323,7 @@ class BYTETracker:
         detections = self.init_track(dets, scores_keep, cls_keep, img if feats is None else feats)
         # Add newly detected tracklets to tracked_stracks
         unconfirmed = []
-        tracked_stracks = []  # type:
+        tracked_stracks = []  # type: List[STrack]
         for track in self.tracked_stracks:
             if not track.is_activated:
                 unconfirmed.append(track)
@@ -408,42 +412,44 @@ class BYTETracker:
 
         return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
 
-    def get_kalmanfilter(self):
-        """
+    def get_kalmanfilter(self) -> KalmanFilterXYAH:
+        """Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
         return KalmanFilterXYAH()
 
-    def init_track(
-
+    def init_track(
+        self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None
+    ) -> List[STrack]:
+        """Initialize object tracking with given detections, scores, and class labels using the STrack algorithm."""
         return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else []  # detections
 
-    def get_dists(self, tracks, detections):
-        """
+    def get_dists(self, tracks: List[STrack], detections: List[STrack]) -> np.ndarray:
+        """Calculate the distance between tracks and detections using IoU and optionally fuse scores."""
         dists = matching.iou_distance(tracks, detections)
         if self.args.fuse_score:
             dists = matching.fuse_score(dists, detections)
         return dists
 
-    def multi_predict(self, tracks):
+    def multi_predict(self, tracks: List[STrack]):
         """Predict the next states for multiple tracks using Kalman filter."""
         STrack.multi_predict(tracks)
 
     @staticmethod
     def reset_id():
-        """
+        """Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
         STrack.reset_id()
 
     def reset(self):
-        """
-        self.tracked_stracks = []  # type:
-        self.lost_stracks = []  # type:
-        self.removed_stracks = []  # type:
+        """Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
+        self.tracked_stracks = []  # type: List[STrack]
+        self.lost_stracks = []  # type: List[STrack]
+        self.removed_stracks = []  # type: List[STrack]
         self.frame_id = 0
         self.kalman_filter = self.get_kalmanfilter()
         self.reset_id()
 
     @staticmethod
-    def joint_stracks(tlista, tlistb):
-        """
+    def joint_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:
+        """Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
         exists = {}
         res = []
         for t in tlista:
@@ -457,14 +463,14 @@ class BYTETracker:
         return res
 
     @staticmethod
-    def sub_stracks(tlista, tlistb):
-        """
+    def sub_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:
+        """Filter out the stracks present in the second list from the first list."""
         track_ids_b = {t.track_id for t in tlistb}
         return [t for t in tlista if t.track_id not in track_ids_b]
 
     @staticmethod
-    def remove_duplicate_stracks(stracksa, stracksb):
-        """
+    def remove_duplicate_stracks(stracksa: List[STrack], stracksb: List[STrack]) -> Tuple[List[STrack], List[STrack]]:
+        """Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
         pdist = matching.iou_distance(stracksa, stracksb)
         pairs = np.where(pdist < 0.15)
         dupa, dupb = [], []
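The byte_tracker.py changes above are mostly documentation and typing: bare `"""` docstrings become one-line summaries, and the coordinate helpers gain explicit `np.ndarray` signatures. For readers unfamiliar with the xyah convention used by `KalmanFilterXYAH`, here is a minimal standalone sketch of the tlwh to xyah conversion that `STrack.tlwh_to_xyah` performs in the diff (the sample box values are made up):

```python
import numpy as np

def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:
    """Convert (top-left x, top-left y, width, height) to (center x, center y, aspect ratio, height)."""
    ret = np.asarray(tlwh, dtype=float).copy()
    ret[:2] += ret[2:] / 2  # shift the top-left corner to the box center
    ret[2] /= ret[3]        # replace width with the aspect ratio w/h
    return ret

# A 40x80 box whose top-left corner sits at (10, 20):
print(tlwh_to_xyah(np.array([10.0, 20.0, 40.0, 80.0])))  # -> [30. 60. 0.5 80.]
```

This xyah vector is the representation the Kalman filter state estimates, which is why `convert_coords` funnels every incoming box through it.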
ultralytics/trackers/track.py
CHANGED
@@ -20,15 +20,11 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
     Initialize trackers for object tracking during prediction.
 
     Args:
-        predictor (
-        persist (bool): Whether to persist the trackers if they already exist.
-
-    Raises:
-        AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
-        ValueError: If the task is 'classify' as classification doesn't support tracking.
+        predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.
+        persist (bool, optional): Whether to persist the trackers if they already exist.
 
     Examples:
-        Initialize trackers for a predictor object
+        Initialize trackers for a predictor object
         >>> predictor = SomePredictorClass()
         >>> on_predict_start(predictor, persist=True)
     """
@@ -79,7 +75,7 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
 
     Args:
         predictor (object): The predictor object containing the predictions.
-        persist (bool): Whether to persist the trackers if they already exist.
+        persist (bool, optional): Whether to persist the trackers if they already exist.
 
     Examples:
         Postprocess predictions and update with tracking
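The track.py hunks above tighten the callback docstrings; the callbacks themselves are wired up by the high-level tracking API. A hedged usage sketch (the weights file and video path are placeholders; `tracker` and `persist` are the real `model.track()` arguments these callbacks read):

```python
from ultralytics import YOLO

# on_predict_start/on_predict_postprocess_end run inside model.track();
# "bytetrack.yaml" and "botsort.yaml" are the two tracker configurations
# the assertion in on_predict_start accepts.
model = YOLO("yolo11n.pt")  # placeholder weights file
results = model.track(source="video.mp4", tracker="bytetrack.yaml", persist=True)  # placeholder video
for r in results:
    print(r.boxes.id)  # per-frame track IDs assigned by the tracker
```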
ultralytics/trackers/utils/gmc.py
CHANGED
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import copy
+from typing import List, Optional
 
 import cv2
 import numpy as np
@@ -19,7 +20,7 @@ class GMC:
         method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
         downscale (int): Factor by which to downscale the frames for processing.
         prevFrame (np.ndarray): Previous frame for tracking.
-        prevKeyPoints (
+        prevKeyPoints (List): Keypoints from the previous frame.
         prevDescriptors (np.ndarray): Descriptors from the previous frame.
         initializedFirstFrame (bool): Flag indicating if the first frame has been processed.
 
@@ -88,13 +89,13 @@
         self.prevDescriptors = None
         self.initializedFirstFrame = False
 
-    def apply(self, raw_frame: np.ndarray, detections:
+    def apply(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:
         """
         Apply object detection on a raw frame using the specified method.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
-            detections (List
+            detections (List, optional): List of detections to be used in the processing.
 
         Returns:
             (np.ndarray): Transformation matrix with shape (2, 3).
@@ -136,23 +137,18 @@
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
         H = np.eye(2, 3, dtype=np.float32)
 
-        # Downscale image
+        # Downscale image for computational efficiency
         if self.downscale > 1.0:
             frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
             frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
 
-        # Handle first frame
+        # Handle first frame initialization
         if not self.initializedFirstFrame:
-            # Initialize data
             self.prevFrame = frame.copy()
-
-            # Initialization done
             self.initializedFirstFrame = True
-
             return H
 
-        # Run the ECC algorithm
-        # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
+        # Run the ECC algorithm to find transformation matrix
         try:
             (_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
         except Exception as e:
@@ -160,13 +156,13 @@
 
         return H
 
-    def apply_features(self, raw_frame: np.ndarray, detections:
+    def apply_features(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:
         """
         Apply feature-based methods like ORB or SIFT to a raw frame.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
-            detections (List
+            detections (List, optional): List of detections to be used in the processing.
 
         Returns:
             (np.ndarray): Transformation matrix with shape (2, 3).
@@ -182,55 +178,50 @@
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
         H = np.eye(2, 3)
 
-        # Downscale image
+        # Downscale image for computational efficiency
         if self.downscale > 1.0:
             frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
             width = width // self.downscale
             height = height // self.downscale
 
-        #
+        # Create mask for keypoint detection, excluding border regions
         mask = np.zeros_like(frame)
         mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255
+
+        # Exclude detection regions from mask to avoid tracking detected objects
         if detections is not None:
             for det in detections:
                 tlbr = (det[:4] / self.downscale).astype(np.int_)
                 mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0
 
+        # Find keypoints and compute descriptors
         keypoints = self.detector.detect(frame, mask)
-
-        # Compute the descriptors
         keypoints, descriptors = self.extractor.compute(frame, keypoints)
 
-        # Handle first frame
+        # Handle first frame initialization
         if not self.initializedFirstFrame:
-            # Initialize data
             self.prevFrame = frame.copy()
             self.prevKeyPoints = copy.copy(keypoints)
             self.prevDescriptors = copy.copy(descriptors)
-
-            # Initialization done
             self.initializedFirstFrame = True
-
             return H
 
-        # Match descriptors
+        # Match descriptors between previous and current frame
         knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
 
-        # Filter matches based on
+        # Filter matches based on spatial distance constraints
         matches = []
         spatialDistances = []
-
         maxSpatialDistance = 0.25 * np.array([width, height])
 
         # Handle empty matches case
         if len(knnMatches) == 0:
-            # Store to next iteration
            self.prevFrame = frame.copy()
             self.prevKeyPoints = copy.copy(keypoints)
             self.prevDescriptors = copy.copy(descriptors)
-
             return H
 
+        # Apply Lowe's ratio test and spatial distance filtering
         for m, n in knnMatches:
             if m.distance < 0.9 * n.distance:
                 prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
@@ -247,11 +238,12 @@
                     spatialDistances.append(spatialDistance)
                     matches.append(m)
 
+        # Filter outliers using statistical analysis
         meanSpatialDistances = np.mean(spatialDistances, 0)
         stdSpatialDistances = np.std(spatialDistances, 0)
-
         inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances
 
+        # Extract good matches and corresponding points
         goodMatches = []
         prevPoints = []
         currPoints = []
@@ -264,39 +256,18 @@
         prevPoints = np.array(prevPoints)
         currPoints = np.array(currPoints)
 
-        #
-        # if False:
-        # import matplotlib.pyplot as plt
-        # matches_img = np.hstack((self.prevFrame, frame))
-        # matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
-        # W = self.prevFrame.shape[1]
-        # for m in goodMatches:
-        # prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
-        # curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
-        # curr_pt[0] += W
-        # color = np.random.randint(0, 255, 3)
-        # color = (int(color[0]), int(color[1]), int(color[2]))
-        #
-        # matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA)
-        # matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1)
-        # matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1)
-        #
-        # plt.figure()
-        # plt.imshow(matches_img)
-        # plt.show()
-
-        # Find rigid matrix
+        # Estimate transformation matrix using RANSAC
         if prevPoints.shape[0] > 4:
             H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
 
-            #
+            # Scale translation components back to original resolution
             if self.downscale > 1.0:
                 H[0, 2] *= self.downscale
                 H[1, 2] *= self.downscale
         else:
             LOGGER.warning("not enough matching points")
 
-        # Store
+        # Store current frame data for next iteration
         self.prevFrame = frame.copy()
         self.prevKeyPoints = copy.copy(keypoints)
         self.prevDescriptors = copy.copy(descriptors)
@@ -324,24 +295,24 @@
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
         H = np.eye(2, 3)
 
-        # Downscale image
+        # Downscale image for computational efficiency
        if self.downscale > 1.0:
             frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
 
-        # Find
+        # Find good features to track
         keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)
 
-        # Handle first frame
+        # Handle first frame initialization
         if not self.initializedFirstFrame or self.prevKeyPoints is None:
             self.prevFrame = frame.copy()
             self.prevKeyPoints = copy.copy(keypoints)
             self.initializedFirstFrame = True
             return H
 
-        #
+        # Calculate optical flow using Lucas-Kanade method
         matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
 
-        #
+        # Extract successfully tracked points
         prevPoints = []
         currPoints = []
 
@@ -353,16 +324,18 @@
         prevPoints = np.array(prevPoints)
         currPoints = np.array(currPoints)
 
-        #
+        # Estimate transformation matrix using RANSAC
         if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == currPoints.shape[0]):
             H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
 
+            # Scale translation components back to original resolution
             if self.downscale > 1.0:
                 H[0, 2] *= self.downscale
                 H[1, 2] *= self.downscale
         else:
             LOGGER.warning("not enough matching points")
 
+        # Store current frame data for next iteration
         self.prevFrame = frame.copy()
         self.prevKeyPoints = copy.copy(keypoints)
 
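Taken together, the new gmc.py comments describe one recurring pattern in every apply method: gather point correspondences, estimate a partial-affine transform with RANSAC, then scale the translation terms back up when the frames were processed at reduced resolution. A minimal sketch of that estimation step with synthetic correspondences (real inputs would come from the keypoint matching or optical flow shown above, and the downscale factor here is an assumed stand-in for `self.downscale`):

```python
import cv2
import numpy as np

# Synthetic matched points simulating a global camera shift of (+3, -2) pixels
prev_pts = np.array([[10, 10], [100, 10], [10, 100], [100, 100], [55, 55]], dtype=np.float32)
curr_pts = prev_pts + np.array([3, -2], dtype=np.float32)

# Estimate a 2x3 partial-affine matrix (rotation, uniform scale, translation) with RANSAC
H, inliers = cv2.estimateAffinePartial2D(prev_pts, curr_pts, method=cv2.RANSAC)

# If frames were downscaled for speed, only the translation components need rescaling
downscale = 2.0  # assumed factor, mirroring self.downscale in GMC
if downscale > 1.0:
    H[0, 2] *= downscale
    H[1, 2] *= downscale

print(H)  # approximately [[1, 0, 6], [0, 1, -4]]
```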