dgenerate-ultralytics-headless 8.3.143__py3-none-any.whl → 8.3.144__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/METADATA +1 -1
- dgenerate_ultralytics_headless-8.3.144.dist-info/RECORD +272 -0
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -157
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -7
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +184 -75
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/solutions/streamlit_inference.py
CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import io
-from typing import Any
+from typing import Any, List

 import cv2

@@ -24,7 +24,7 @@ class Inference:
         model_path (str): Path to the loaded model.
         model (YOLO): The YOLO model instance.
         source (str): Selected video source (webcam or video file).
-        enable_trk (
+        enable_trk (bool): Enable tracking option.
         conf (float): Confidence threshold for detection.
         iou (float): IoU threshold for non-maximum suppression.
         org_frame (Any): Container for the original frame to be displayed.
@@ -33,14 +33,19 @@ class Inference:
         selected_ind (List[int]): List of selected class indices for detection.

     Methods:
-        web_ui:
-        sidebar:
-        source_upload:
-        configure:
-        inference:
+        web_ui: Set up the Streamlit web interface with custom HTML elements.
+        sidebar: Configure the Streamlit sidebar for model and inference settings.
+        source_upload: Handle video file uploads through the Streamlit interface.
+        configure: Configure the model and load selected classes for inference.
+        inference: Perform real-time object detection inference.

     Examples:
-
+        Create an Inference instance with a custom model
+        >>> inf = Inference(model="path/to/model.pt")
+        >>> inf.inference()
+
+        Create an Inference instance with default settings
+        >>> inf = Inference()
         >>> inf.inference()
     """

@@ -62,7 +67,7 @@ class Inference:
         self.org_frame = None  # Container for the original frame display
         self.ann_frame = None  # Container for the annotated frame display
         self.vid_file_name = None  # Video file name or webcam index
-        self.selected_ind = []  # List of selected class indices for detection
+        self.selected_ind: List[int] = []  # List of selected class indices for detection
         self.model = None  # YOLO model instance

         self.temp_dict = {"model": None, **kwargs}
@@ -73,7 +78,7 @@ class Inference:
         LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")

     def web_ui(self):
-        """
+        """Set up the Streamlit web interface with custom HTML elements."""
         menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style

         # Main title of streamlit application
@@ -102,7 +107,7 @@ class Inference:
             "Video",
             ("webcam", "video"),
         )  # Add source selection dropdown
-        self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No"))  # Enable object tracking
+        self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes"  # Enable object tracking
         self.conf = float(
             self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
         )  # Slider for confidence
@@ -166,7 +171,7 @@ class Inference:
                 break

             # Process frame with model
-            if self.enable_trk
+            if self.enable_trk:
                 results = self.model.track(
                     frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
                 )
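The `enable_trk` change above means the flag is now stored internally as a `bool` ("Yes" → `True`) rather than the raw radio string. A minimal usage sketch based on the docstring example in this diff (the model path is a placeholder; the app is normally driven through a Streamlit session):

```python
# Minimal sketch based on the docstring example above; "path/to/model.pt" is a placeholder.
from ultralytics.solutions.streamlit_inference import Inference

inf = Inference(model="path/to/model.pt")  # or Inference() for the default model
inf.inference()  # builds the Streamlit UI; tracking runs when "Enable Tracking" is set to "Yes"
```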
ultralytics/solutions/trackzone.py
CHANGED
@@ -23,9 +23,9 @@ class TrackZone(BaseSolution):
         clss (List[int]): Class indices of tracked objects.

     Methods:
-        process:
-        extract_tracks:
-        display_output:
+        process: Process each frame of the video, applying region-based tracking.
+        extract_tracks: Extract tracking information from the input frame.
+        display_output: Display the processed output.

     Examples:
         >>> tracker = TrackZone()
@@ -82,7 +82,7 @@ class TrackZone(BaseSolution):
             )

         plot_im = annotator.result()
-        self.display_output(plot_im)  #
+        self.display_output(plot_im)  # Display output with base class function

         # Return a SolutionResults
         return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
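For reference, a hedged sketch of driving the `process` method documented above; the region polygon, model name, and keyword arguments here are illustrative assumptions rather than values taken from this diff:

```python
# Illustrative sketch; region points, model name, and kwargs are assumptions.
import cv2
from ultralytics import solutions

trackzone = solutions.TrackZone(
    model="yolo11n.pt",
    region=[(150, 150), (1130, 150), (1130, 570), (150, 570)],  # zone to track inside
    show=False,
)
frame = cv2.imread("frame.jpg")  # any BGR frame
results = trackzone.process(frame)  # SolutionResults with plot_im and total_tracks
print(results.total_tracks)
```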
ultralytics/trackers/basetrack.py
CHANGED
@@ -2,6 +2,7 @@
 """Module defines the base classes and structures for object tracking in YOLO."""

 from collections import OrderedDict
+from typing import Any

 import numpy as np

@@ -66,15 +67,7 @@ class BaseTrack:
     _count = 0

     def __init__(self):
-        """
-        Initialize a new track with a unique ID and foundational tracking attributes.
-
-        Examples:
-            Initialize a new track
-            >>> track = BaseTrack()
-            >>> print(track.track_id)
-            0
-        """
+        """Initialize a new track with a unique ID and foundational tracking attributes."""
         self.track_id = 0
         self.is_activated = False
         self.state = TrackState.New
@@ -88,37 +81,37 @@ class BaseTrack:
         self.location = (np.inf, np.inf)

     @property
-    def end_frame(self):
-        """
+    def end_frame(self) -> int:
+        """Return the ID of the most recent frame where the object was tracked."""
         return self.frame_id

     @staticmethod
-    def next_id():
+    def next_id() -> int:
         """Increment and return the next unique global track ID for object tracking."""
         BaseTrack._count += 1
         return BaseTrack._count

-    def activate(self, *args):
-        """
+    def activate(self, *args: Any) -> None:
+        """Activate the track with provided arguments, initializing necessary attributes for tracking."""
         raise NotImplementedError

-    def predict(self):
-        """
+    def predict(self) -> None:
+        """Predict the next state of the track based on the current state and tracking model."""
         raise NotImplementedError

-    def update(self, *args, **kwargs):
-        """
+    def update(self, *args: Any, **kwargs: Any) -> None:
+        """Update the track with new observations and data, modifying its state and attributes accordingly."""
         raise NotImplementedError

-    def mark_lost(self):
-        """
+    def mark_lost(self) -> None:
+        """Mark the track as lost by updating its state to TrackState.Lost."""
         self.state = TrackState.Lost

-    def mark_removed(self):
-        """
+    def mark_removed(self) -> None:
+        """Mark the track as removed by setting its state to TrackState.Removed."""
         self.state = TrackState.Removed

     @staticmethod
-    def reset_id():
+    def reset_id() -> None:
         """Reset the global track ID counter to its initial value."""
         BaseTrack._count = 0
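The annotations above treat `BaseTrack` as an abstract interface (`activate`, `predict`, and `update` raise `NotImplementedError`). A purely illustrative subclass showing how those hooks and the ID helpers fit together:

```python
# Purely illustrative subclass of the abstract interface annotated above.
from ultralytics.trackers.basetrack import BaseTrack, TrackState

class DummyTrack(BaseTrack):
    def activate(self, *args):
        self.track_id = self.next_id()  # claim the next global track ID
        self.state = TrackState.Tracked
        self.is_activated = True

    def predict(self):
        pass  # no motion model in this sketch

    def update(self, *args, **kwargs):
        self.frame_id += 1

track = DummyTrack()
track.activate()
track.mark_lost()                      # state -> TrackState.Lost
print(track.track_id, track.end_frame)
BaseTrack.reset_id()                   # reset the global counter, e.g. between videos
```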
ultralytics/trackers/bot_sort.py
CHANGED
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from collections import deque
+from typing import Any, List, Optional

 import numpy as np
 import torch
@@ -51,7 +52,9 @@ class BOTrack(STrack):

     shared_kalman = KalmanFilterXYWH()

-    def __init__(
+    def __init__(
+        self, tlwh: np.ndarray, score: float, cls: int, feat: Optional[np.ndarray] = None, feat_history: int = 50
+    ):
         """
         Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.

@@ -59,7 +62,7 @@ class BOTrack(STrack):
             tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).
             score (float): Confidence score of the detection.
             cls (int): Class ID of the detected object.
-            feat (np.ndarray
+            feat (np.ndarray, optional): Feature vector associated with the detection.
             feat_history (int): Maximum length of the feature history deque.

         Examples:
@@ -79,7 +82,7 @@ class BOTrack(STrack):
         self.features = deque([], maxlen=feat_history)
         self.alpha = 0.9

-    def update_features(self, feat):
+    def update_features(self, feat: np.ndarray) -> None:
         """Update the feature vector and apply exponential moving average smoothing."""
         feat /= np.linalg.norm(feat)
         self.curr_feat = feat
@@ -90,7 +93,7 @@ class BOTrack(STrack):
             self.features.append(feat)
         self.smooth_feat /= np.linalg.norm(self.smooth_feat)

-    def predict(self):
+    def predict(self) -> None:
         """Predict the object's future state using the Kalman filter to update its mean and covariance."""
         mean_state = self.mean.copy()
         if self.state != TrackState.Tracked:
@@ -99,20 +102,20 @@ class BOTrack(STrack):

         self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

-    def re_activate(self, new_track, frame_id, new_id=False):
+    def re_activate(self, new_track: "BOTrack", frame_id: int, new_id: bool = False) -> None:
         """Reactivate a track with updated features and optionally assign a new ID."""
         if new_track.curr_feat is not None:
             self.update_features(new_track.curr_feat)
         super().re_activate(new_track, frame_id, new_id)

-    def update(self, new_track, frame_id):
+    def update(self, new_track: "BOTrack", frame_id: int) -> None:
         """Update the track with new detection information and the current frame ID."""
         if new_track.curr_feat is not None:
             self.update_features(new_track.curr_feat)
         super().update(new_track, frame_id)

     @property
-    def tlwh(self):
+    def tlwh(self) -> np.ndarray:
         """Return the current bounding box position in `(top left x, top left y, width, height)` format."""
         if self.mean is None:
             return self._tlwh.copy()
@@ -121,7 +124,7 @@ class BOTrack(STrack):
         return ret

     @staticmethod
-    def multi_predict(stracks):
+    def multi_predict(stracks: List["BOTrack"]) -> None:
         """Predict the mean and covariance for multiple object tracks using a shared Kalman filter."""
         if len(stracks) <= 0:
             return
@@ -136,12 +139,12 @@ class BOTrack(STrack):
             stracks[i].mean = mean
             stracks[i].covariance = cov

-    def convert_coords(self, tlwh):
+    def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
         """Convert tlwh bounding box coordinates to xywh format."""
         return self.tlwh_to_xywh(tlwh)

     @staticmethod
-    def tlwh_to_xywh(tlwh):
+    def tlwh_to_xywh(tlwh: np.ndarray) -> np.ndarray:
         """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
         ret = np.asarray(tlwh).copy()
         ret[:2] += ret[2:] / 2
@@ -176,12 +179,12 @@ class BOTSORT(BYTETracker):
     The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.
     """

-    def __init__(self, args, frame_rate=30):
+    def __init__(self, args: Any, frame_rate: int = 30):
         """
         Initialize BOTSORT object with ReID module and GMC algorithm.

         Args:
-            args (
+            args (Any): Parsed command-line arguments containing tracking parameters.
             frame_rate (int): Frame rate of the video being processed.

         Examples:
@@ -203,11 +206,13 @@ class BOTSORT(BYTETracker):
             else None
         )

-    def get_kalmanfilter(self):
+    def get_kalmanfilter(self) -> KalmanFilterXYWH:
         """Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
         return KalmanFilterXYWH()

-    def init_track(
+    def init_track(
+        self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None
+    ) -> List[BOTrack]:
         """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
         if len(dets) == 0:
             return []
@@ -217,7 +222,7 @@ class BOTSORT(BYTETracker):
         else:
             return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)]  # detections

-    def get_dists(self, tracks, detections):
+    def get_dists(self, tracks: List[BOTrack], detections: List[BOTrack]) -> np.ndarray:
         """Calculate distances between tracks and detections using IoU and optionally ReID embeddings."""
         dists = matching.iou_distance(tracks, detections)
         dists_mask = dists > (1 - self.proximity_thresh)
@@ -232,11 +237,11 @@ class BOTSORT(BYTETracker):
             dists = np.minimum(dists, emb_dists)
         return dists

-    def multi_predict(self, tracks):
+    def multi_predict(self, tracks: List[BOTrack]) -> None:
         """Predict the mean and covariance of multiple object tracks using a shared Kalman filter."""
         BOTrack.multi_predict(tracks)

-    def reset(self):
+    def reset(self) -> None:
         """Reset the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
         super().reset()
         self.gmc.reset_params()
@@ -245,14 +250,19 @@ class BOTSORT(BYTETracker):
 class ReID:
     """YOLO model as encoder for re-identification."""

-    def __init__(self, model):
-        """
+    def __init__(self, model: str):
+        """
+        Initialize encoder for re-identification.
+
+        Args:
+            model (str): Path to the YOLO model for re-identification.
+        """
         from ultralytics import YOLO

         self.model = YOLO(model)
         self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False)  # initialize

-    def __call__(self, img, dets):
+    def __call__(self, img: np.ndarray, dets: np.ndarray) -> List[np.ndarray]:
         """Extract embeddings for detected objects."""
         feats = self.model([save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))])
         if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]:
ultralytics/trackers/byte_tracker.py
CHANGED
@@ -1,5 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from typing import Any, List, Optional, Tuple
+
 import numpy as np

 from ..utils import LOGGER
@@ -29,16 +31,17 @@ class STrack(BaseTrack):
         idx (int): Index or identifier for the object.
         frame_id (int): Current frame ID.
         start_frame (int): Frame where the object was first detected.
+        angle (float | None): Optional angle information for oriented bounding boxes.

     Methods:
-        predict
-        multi_predict
-        multi_gmc
-        activate
-        re_activate
-        update
-        convert_coords
-        tlwh_to_xyah
+        predict: Predict the next state of the object using Kalman filter.
+        multi_predict: Predict the next states for multiple tracks.
+        multi_gmc: Update multiple track states using a homography matrix.
+        activate: Activate a new tracklet.
+        re_activate: Reactivate a previously lost tracklet.
+        update: Update the state of a matched track.
+        convert_coords: Convert bounding box to x-y-aspect-height format.
+        tlwh_to_xyah: Convert tlwh bounding box to xyah format.

     Examples:
         Initialize and activate a new track
@@ -48,7 +51,7 @@ class STrack(BaseTrack):

     shared_kalman = KalmanFilterXYAH()

-    def __init__(self, xywh, score, cls):
+    def __init__(self, xywh: List[float], score: float, cls: Any):
         """
         Initialize a new STrack instance.

@@ -79,14 +82,14 @@ class STrack(BaseTrack):
         self.angle = xywh[4] if len(xywh) == 6 else None

     def predict(self):
-        """
+        """Predict the next state (mean and covariance) of the object using the Kalman filter."""
         mean_state = self.mean.copy()
         if self.state != TrackState.Tracked:
             mean_state[7] = 0
         self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

     @staticmethod
-    def multi_predict(stracks):
+    def multi_predict(stracks: List["STrack"]):
         """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
         if len(stracks) <= 0:
             return
@@ -101,7 +104,7 @@ class STrack(BaseTrack):
             stracks[i].covariance = cov

     @staticmethod
-    def multi_gmc(stracks, H=np.eye(2, 3)):
+    def multi_gmc(stracks: List["STrack"], H: np.ndarray = np.eye(2, 3)):
         """Update state tracks positions and covariances using a homography matrix for multiple tracks."""
         if len(stracks) > 0:
             multi_mean = np.asarray([st.mean.copy() for st in stracks])
@@ -119,7 +122,7 @@ class STrack(BaseTrack):
             stracks[i].mean = mean
             stracks[i].covariance = cov

-    def activate(self, kalman_filter, frame_id):
+    def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):
         """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
         self.kalman_filter = kalman_filter
         self.track_id = self.next_id()
@@ -132,8 +135,8 @@ class STrack(BaseTrack):
         self.frame_id = frame_id
         self.start_frame = frame_id

-    def re_activate(self, new_track, frame_id, new_id=False):
-        """
+    def re_activate(self, new_track: "STrack", frame_id: int, new_id: bool = False):
+        """Reactivate a previously lost track using new detection data and update its state and attributes."""
         self.mean, self.covariance = self.kalman_filter.update(
             self.mean, self.covariance, self.convert_coords(new_track.tlwh)
         )
@@ -148,7 +151,7 @@ class STrack(BaseTrack):
         self.angle = new_track.angle
         self.idx = new_track.idx

-    def update(self, new_track, frame_id):
+    def update(self, new_track: "STrack", frame_id: int):
         """
         Update the state of a matched track.

@@ -177,13 +180,13 @@ class STrack(BaseTrack):
         self.angle = new_track.angle
         self.idx = new_track.idx

-    def convert_coords(self, tlwh):
+    def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
         """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
         return self.tlwh_to_xyah(tlwh)

     @property
-    def tlwh(self):
-        """
+    def tlwh(self) -> np.ndarray:
+        """Get the bounding box in top-left-width-height format from the current state estimate."""
         if self.mean is None:
             return self._tlwh.copy()
         ret = self.mean[:4].copy()
@@ -192,14 +195,14 @@ class STrack(BaseTrack):
         return ret

     @property
-    def xyxy(self):
-        """
+    def xyxy(self) -> np.ndarray:
+        """Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
         ret = self.tlwh.copy()
         ret[2:] += ret[:2]
         return ret

     @staticmethod
-    def tlwh_to_xyah(tlwh):
+    def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:
         """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
         ret = np.asarray(tlwh).copy()
         ret[:2] += ret[2:] / 2
@@ -207,28 +210,28 @@ class STrack(BaseTrack):
         return ret

     @property
-    def xywh(self):
-        """
+    def xywh(self) -> np.ndarray:
+        """Get the current position of the bounding box in (center x, center y, width, height) format."""
         ret = np.asarray(self.tlwh).copy()
         ret[:2] += ret[2:] / 2
         return ret

     @property
-    def xywha(self):
-        """
+    def xywha(self) -> np.ndarray:
+        """Get position in (center x, center y, width, height, angle) format, warning if angle is missing."""
         if self.angle is None:
             LOGGER.warning("`angle` attr not found, returning `xywh` instead.")
             return self.xywh
         return np.concatenate([self.xywh, self.angle[None]])

     @property
-    def result(self):
-        """
+    def result(self) -> List[float]:
+        """Get the current tracking results in the appropriate bounding box format."""
         coords = self.xyxy if self.angle is None else self.xywha
         return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]

-    def __repr__(self):
-        """
+    def __repr__(self) -> str:
+        """Return a string representation of the STrack object including start frame, end frame, and track ID."""
         return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"


@@ -250,15 +253,16 @@ class BYTETracker:
         kalman_filter (KalmanFilterXYAH): Kalman Filter object.

     Methods:
-        update
-        get_kalmanfilter
-        init_track
-        get_dists
-        multi_predict
-        reset_id
-
-
-
+        update: Update object tracker with new detections.
+        get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.
+        init_track: Initialize object tracking with detections.
+        get_dists: Calculate the distance between tracks and detections.
+        multi_predict: Predict the location of tracks.
+        reset_id: Reset the ID counter of STrack.
+        reset: Reset the tracker by clearing all tracks.
+        joint_stracks: Combine two lists of stracks.
+        sub_stracks: Filter out the stracks present in the second list from the first list.
+        remove_duplicate_stracks: Remove duplicate stracks based on IoU.

     Examples:
         Initialize BYTETracker and update with detection results
@@ -267,7 +271,7 @@ class BYTETracker:
         >>> tracked_objects = tracker.update(results)
     """

-    def __init__(self, args, frame_rate=30):
+    def __init__(self, args, frame_rate: int = 30):
         """
         Initialize a BYTETracker instance for object tracking.

@@ -280,9 +284,9 @@ class BYTETracker:
             >>> args = Namespace(track_buffer=30)
             >>> tracker = BYTETracker(args, frame_rate=30)
         """
-        self.tracked_stracks = []  # type:
-        self.lost_stracks = []  # type:
-        self.removed_stracks = []  # type:
+        self.tracked_stracks = []  # type: List[STrack]
+        self.lost_stracks = []  # type: List[STrack]
+        self.removed_stracks = []  # type: List[STrack]

         self.frame_id = 0
         self.args = args
@@ -290,8 +294,8 @@ class BYTETracker:
         self.kalman_filter = self.get_kalmanfilter()
         self.reset_id()

-    def update(self, results, img=None, feats=None):
-        """
+    def update(self, results, img: Optional[np.ndarray] = None, feats: Optional[np.ndarray] = None) -> np.ndarray:
+        """Update the tracker with new detections and return the current list of tracked objects."""
         self.frame_id += 1
         activated_stracks = []
         refind_stracks = []
@@ -319,7 +323,7 @@ class BYTETracker:
         detections = self.init_track(dets, scores_keep, cls_keep, img if feats is None else feats)
         # Add newly detected tracklets to tracked_stracks
         unconfirmed = []
-        tracked_stracks = []  # type:
+        tracked_stracks = []  # type: List[STrack]
         for track in self.tracked_stracks:
             if not track.is_activated:
                 unconfirmed.append(track)
@@ -408,42 +412,44 @@ class BYTETracker:

         return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)

-    def get_kalmanfilter(self):
-        """
+    def get_kalmanfilter(self) -> KalmanFilterXYAH:
+        """Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
         return KalmanFilterXYAH()

-    def init_track(
-
+    def init_track(
+        self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None
+    ) -> List[STrack]:
+        """Initialize object tracking with given detections, scores, and class labels using the STrack algorithm."""
         return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else []  # detections

-    def get_dists(self, tracks, detections):
-        """
+    def get_dists(self, tracks: List[STrack], detections: List[STrack]) -> np.ndarray:
+        """Calculate the distance between tracks and detections using IoU and optionally fuse scores."""
         dists = matching.iou_distance(tracks, detections)
         if self.args.fuse_score:
             dists = matching.fuse_score(dists, detections)
         return dists

-    def multi_predict(self, tracks):
+    def multi_predict(self, tracks: List[STrack]):
         """Predict the next states for multiple tracks using Kalman filter."""
         STrack.multi_predict(tracks)

     @staticmethod
     def reset_id():
-        """
+        """Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
         STrack.reset_id()

     def reset(self):
-        """
-        self.tracked_stracks = []  # type:
-        self.lost_stracks = []  # type:
-        self.removed_stracks = []  # type:
+        """Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
+        self.tracked_stracks = []  # type: List[STrack]
+        self.lost_stracks = []  # type: List[STrack]
+        self.removed_stracks = []  # type: List[STrack]
         self.frame_id = 0
         self.kalman_filter = self.get_kalmanfilter()
         self.reset_id()

     @staticmethod
-    def joint_stracks(tlista, tlistb):
-        """
+    def joint_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:
+        """Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
         exists = {}
         res = []
         for t in tlista:
@@ -457,14 +463,14 @@ class BYTETracker:
         return res

     @staticmethod
-    def sub_stracks(tlista, tlistb):
-        """
+    def sub_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:
+        """Filter out the stracks present in the second list from the first list."""
         track_ids_b = {t.track_id for t in tlistb}
         return [t for t in tlista if t.track_id not in track_ids_b]

     @staticmethod
-    def remove_duplicate_stracks(stracksa, stracksb):
-        """
+    def remove_duplicate_stracks(stracksa: List[STrack], stracksb: List[STrack]) -> Tuple[List[STrack], List[STrack]]:
+        """Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
         pdist = matching.iou_distance(stracksa, stracksb)
         pairs = np.where(pdist < 0.15)
         dupa, dupb = [], []
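The coordinate converters annotated above are easy to sanity-check by hand; a standalone sketch of the same arithmetic (for xyah the width is additionally divided by the height, per the tlwh_to_xyah docstring):

```python
# Standalone check of the tlwh -> xywh / xyah arithmetic used by the trackers above.
import numpy as np

tlwh = np.array([100.0, 50.0, 40.0, 80.0])  # top-left x, top-left y, width, height

xywh = tlwh.copy()
xywh[:2] += xywh[2:] / 2   # shift the top-left corner to the box center
print(xywh)                # [120.  90.  40.  80.] -> center x, center y, width, height

xyah = xywh.copy()
xyah[2] /= xyah[3]         # aspect ratio = width / height
print(xyah)                # [120.  90.   0.5  80.] -> center x, center y, aspect, height
```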