dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff compares the content of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -2,13 +2,13 @@
 """Module defines the base classes and structures for object tracking in YOLO."""
 
 from collections import OrderedDict
+from typing import Any
 
 import numpy as np
 
 
 class TrackState:
-    """
-    Enumeration class representing the possible states of an object being tracked.
+    """Enumeration class representing the possible states of an object being tracked.
 
     Attributes:
         New (int): State when the object is newly detected.
@@ -29,8 +29,7 @@ class TrackState:
 
 
 class BaseTrack:
-    """
-    Base class for object tracking, providing foundational attributes and methods.
+    """Base class for object tracking, providing foundational attributes and methods.
 
     Attributes:
         _count (int): Class-level counter for unique track IDs.
@@ -66,15 +65,7 @@ class BaseTrack:
     _count = 0
 
     def __init__(self):
-        """
-        Initialize a new track with a unique ID and foundational tracking attributes.
-
-        Examples:
-            Initialize a new track
-            >>> track = BaseTrack()
-            >>> print(track.track_id)
-            0
-        """
+        """Initialize a new track with a unique ID and foundational tracking attributes."""
         self.track_id = 0
         self.is_activated = False
         self.state = TrackState.New
@@ -88,37 +79,37 @@ class BaseTrack:
         self.location = (np.inf, np.inf)
 
     @property
-    def end_frame(self):
-        """
+    def end_frame(self) -> int:
+        """Return the ID of the most recent frame where the object was tracked."""
         return self.frame_id
 
     @staticmethod
-    def next_id():
+    def next_id() -> int:
         """Increment and return the next unique global track ID for object tracking."""
         BaseTrack._count += 1
         return BaseTrack._count
 
-    def activate(self, *args):
-        """
+    def activate(self, *args: Any) -> None:
+        """Activate the track with provided arguments, initializing necessary attributes for tracking."""
         raise NotImplementedError
 
-    def predict(self):
-        """
+    def predict(self) -> None:
+        """Predict the next state of the track based on the current state and tracking model."""
         raise NotImplementedError
 
-    def update(self, *args, **kwargs):
-        """
+    def update(self, *args: Any, **kwargs: Any) -> None:
+        """Update the track with new observations and data, modifying its state and attributes accordingly."""
        raise NotImplementedError
 
-    def mark_lost(self):
-        """
+    def mark_lost(self) -> None:
+        """Mark the track as lost by updating its state to TrackState.Lost."""
         self.state = TrackState.Lost
 
-    def mark_removed(self):
-        """
+    def mark_removed(self) -> None:
+        """Mark the track as removed by setting its state to TrackState.Removed."""
         self.state = TrackState.Removed
 
     @staticmethod
-    def reset_id():
+    def reset_id() -> None:
         """Reset the global track ID counter to its initial value."""
         BaseTrack._count = 0
ultralytics/trackers/bot_sort.py CHANGED
@@ -1,6 +1,9 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from collections import deque
+from typing import Any
 
 import numpy as np
 import torch
@@ -16,8 +19,7 @@ from .utils.kalman_filter import KalmanFilterXYWH
 
 
 class BOTrack(STrack):
-    """
-    An extended version of the STrack class for YOLO, adding object tracking features.
+    """An extended version of the STrack class for YOLO, adding object tracking features.
 
     This class extends the STrack class to include additional functionalities for object tracking, such as feature
     smoothing, Kalman filter prediction, and reactivation of tracks.
@@ -51,26 +53,27 @@ class BOTrack(STrack):
 
     shared_kalman = KalmanFilterXYWH()
 
-    def __init__(
-
-
+    def __init__(
+        self, xywh: np.ndarray, score: float, cls: int, feat: np.ndarray | None = None, feat_history: int = 50
+    ):
+        """Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
 
         Args:
-
+            xywh (np.ndarray): Bounding box coordinates in xywh format (center x, center y, width, height).
             score (float): Confidence score of the detection.
             cls (int): Class ID of the detected object.
-            feat (np.ndarray
+            feat (np.ndarray, optional): Feature vector associated with the detection.
             feat_history (int): Maximum length of the feature history deque.
 
         Examples:
             Initialize a BOTrack object with bounding box, score, class ID, and feature vector
-            >>>
+            >>> xywh = np.array([100, 150, 60, 50])
             >>> score = 0.9
             >>> cls = 1
             >>> feat = np.random.rand(128)
-            >>> bo_track = BOTrack(
+            >>> bo_track = BOTrack(xywh, score, cls, feat)
         """
-        super().__init__(
+        super().__init__(xywh, score, cls)
 
         self.smooth_feat = None
         self.curr_feat = None
@@ -79,7 +82,7 @@ class BOTrack(STrack):
         self.features = deque([], maxlen=feat_history)
         self.alpha = 0.9
 
-    def update_features(self, feat):
+    def update_features(self, feat: np.ndarray) -> None:
         """Update the feature vector and apply exponential moving average smoothing."""
         feat /= np.linalg.norm(feat)
         self.curr_feat = feat
@@ -90,7 +93,7 @@ class BOTrack(STrack):
         self.features.append(feat)
         self.smooth_feat /= np.linalg.norm(self.smooth_feat)
 
-    def predict(self):
+    def predict(self) -> None:
         """Predict the object's future state using the Kalman filter to update its mean and covariance."""
         mean_state = self.mean.copy()
         if self.state != TrackState.Tracked:
@@ -99,20 +102,20 @@ class BOTrack(STrack):
 
         self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
 
-    def re_activate(self, new_track, frame_id, new_id=False):
+    def re_activate(self, new_track: BOTrack, frame_id: int, new_id: bool = False) -> None:
         """Reactivate a track with updated features and optionally assign a new ID."""
         if new_track.curr_feat is not None:
             self.update_features(new_track.curr_feat)
         super().re_activate(new_track, frame_id, new_id)
 
-    def update(self, new_track, frame_id):
+    def update(self, new_track: BOTrack, frame_id: int) -> None:
         """Update the track with new detection information and the current frame ID."""
         if new_track.curr_feat is not None:
             self.update_features(new_track.curr_feat)
         super().update(new_track, frame_id)
 
     @property
-    def tlwh(self):
+    def tlwh(self) -> np.ndarray:
         """Return the current bounding box position in `(top left x, top left y, width, height)` format."""
         if self.mean is None:
             return self._tlwh.copy()
@@ -121,7 +124,7 @@ class BOTrack(STrack):
         return ret
 
     @staticmethod
-    def multi_predict(stracks):
+    def multi_predict(stracks: list[BOTrack]) -> None:
         """Predict the mean and covariance for multiple object tracks using a shared Kalman filter."""
         if len(stracks) <= 0:
             return
@@ -136,12 +139,12 @@ class BOTrack(STrack):
             stracks[i].mean = mean
             stracks[i].covariance = cov
 
-    def convert_coords(self, tlwh):
+    def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
         """Convert tlwh bounding box coordinates to xywh format."""
         return self.tlwh_to_xywh(tlwh)
 
     @staticmethod
-    def tlwh_to_xywh(tlwh):
+    def tlwh_to_xywh(tlwh: np.ndarray) -> np.ndarray:
         """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
         ret = np.asarray(tlwh).copy()
         ret[:2] += ret[2:] / 2
@@ -149,8 +152,7 @@ class BOTrack(STrack):
 
 
 class BOTSORT(BYTETracker):
-    """
-    An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm.
+    """An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm.
 
     Attributes:
         proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
@@ -172,16 +174,15 @@ class BOTSORT(BYTETracker):
         >>> bot_sort.init_track(dets, scores, cls, img)
         >>> bot_sort.multi_predict(tracks)
 
-
+    Notes:
         The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.
     """
 
-    def __init__(self, args, frame_rate=30):
-        """
-        Initialize BOTSORT object with ReID module and GMC algorithm.
+    def __init__(self, args: Any, frame_rate: int = 30):
+        """Initialize BOTSORT object with ReID module and GMC algorithm.
 
         Args:
-            args (
+            args (Any): Parsed command-line arguments containing tracking parameters.
             frame_rate (int): Frame rate of the video being processed.
 
         Examples:
@@ -203,21 +204,23 @@ class BOTSORT(BYTETracker):
             else None
         )
 
-    def get_kalmanfilter(self):
+    def get_kalmanfilter(self) -> KalmanFilterXYWH:
         """Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
         return KalmanFilterXYWH()
 
-    def init_track(self,
+    def init_track(self, results, img: np.ndarray | None = None) -> list[BOTrack]:
         """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
-        if len(
+        if len(results) == 0:
             return []
+        bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
+        bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
         if self.args.with_reid and self.encoder is not None:
-            features_keep = self.encoder(img,
-            return [BOTrack(
+            features_keep = self.encoder(img, bboxes)
+            return [BOTrack(xywh, s, c, f) for (xywh, s, c, f) in zip(bboxes, results.conf, results.cls, features_keep)]
         else:
-            return [BOTrack(
+            return [BOTrack(xywh, s, c) for (xywh, s, c) in zip(bboxes, results.conf, results.cls)]
 
-    def get_dists(self, tracks, detections):
+    def get_dists(self, tracks: list[BOTrack], detections: list[BOTrack]) -> np.ndarray:
         """Calculate distances between tracks and detections using IoU and optionally ReID embeddings."""
         dists = matching.iou_distance(tracks, detections)
         dists_mask = dists > (1 - self.proximity_thresh)
@@ -232,11 +235,11 @@ class BOTSORT(BYTETracker):
             dists = np.minimum(dists, emb_dists)
         return dists
 
-    def multi_predict(self, tracks):
+    def multi_predict(self, tracks: list[BOTrack]) -> None:
         """Predict the mean and covariance of multiple object tracks using a shared Kalman filter."""
         BOTrack.multi_predict(tracks)
 
-    def reset(self):
+    def reset(self) -> None:
         """Reset the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
         super().reset()
         self.gmc.reset_params()
@@ -245,16 +248,22 @@ class BOTSORT(BYTETracker):
 class ReID:
     """YOLO model as encoder for re-identification."""
 
-    def __init__(self, model):
-        """Initialize encoder for re-identification.
+    def __init__(self, model: str):
+        """Initialize encoder for re-identification.
+
+        Args:
+            model (str): Path to the YOLO model for re-identification.
+        """
         from ultralytics import YOLO
 
         self.model = YOLO(model)
-        self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False)  #
+        self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False, save=False)  # init
 
-    def __call__(self, img, dets):
+    def __call__(self, img: np.ndarray, dets: np.ndarray) -> list[np.ndarray]:
         """Extract embeddings for detected objects."""
-        feats = self.model
+        feats = self.model.predictor(
+            [save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))]
+        )
         if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]:
             feats = feats[0]  # batched prediction with non-PyTorch backend
         return [f.cpu().numpy() for f in feats]