ultralytics 8.3.142__py3-none-any.whl → 8.3.144__py3-none-any.whl
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +12 -12
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -157
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +16 -8
- ultralytics/solutions/object_cropper.py +12 -5
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +215 -85
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
- ultralytics-8.3.144.dist-info/RECORD +272 -0
- ultralytics-8.3.142.dist-info/RECORD +0 -272
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/solutions/speed_estimation.py
CHANGED
@@ -12,24 +12,29 @@ class SpeedEstimator(BaseSolution):
     A class to estimate the speed of objects in a real-time video stream based on their tracks.

     This class extends the BaseSolution class and provides functionality for estimating object speeds using
-    tracking data in video streams.
+    tracking data in video streams. Speed is calculated based on pixel displacement over time and converted
+    to real-world units using a configurable meters-per-pixel scale factor.

     Attributes:
-
-
-
-
-
+        fps (float): Video frame rate for time calculations.
+        frame_count (int): Global frame counter for tracking temporal information.
+        trk_frame_ids (dict): Maps track IDs to their first frame index.
+        spd (dict): Final speed per object in km/h once locked.
+        trk_hist (dict): Maps track IDs to deque of position history.
+        locked_ids (set): Track IDs whose speed has been finalized.
+        max_hist (int): Required frame history before computing speed.
+        meter_per_pixel (float): Real-world meters represented by one pixel for scene scale conversion.
+        max_speed (int): Maximum allowed object speed; values above this will be capped.

     Methods:
-
-
-
-
-        display_output: Displays the output with annotations.
+        process: Process input frames to estimate object speeds based on tracking data.
+        store_tracking_history: Store the tracking history for an object.
+        extract_tracks: Extract tracks from the current frame.
+        display_output: Display the output with annotations.

     Examples:
-
+        Initialize speed estimator and process a frame
+        >>> estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)
         >>> frame = cv2.imread("frame.jpg")
         >>> results = estimator.process(frame)
         >>> cv2.imshow("Speed Estimation", results.plot_im)
@@ -44,15 +49,15 @@ class SpeedEstimator(BaseSolution):
         """
         super().__init__(**kwargs)

-        self.fps = self.CFG["fps"]  #
-        self.frame_count = 0  #
+        self.fps = self.CFG["fps"]  # Video frame rate for time calculations
+        self.frame_count = 0  # Global frame counter
         self.trk_frame_ids = {}  # Track ID → first frame index
         self.spd = {}  # Final speed per object (km/h), once locked
         self.trk_hist = {}  # Track ID → deque of (time, position)
         self.locked_ids = set()  # Track IDs whose speed has been finalized
         self.max_hist = self.CFG["max_hist"]  # Required frame history before computing speed
         self.meter_per_pixel = self.CFG["meter_per_pixel"]  # Scene scale, depends on camera details
-        self.max_speed = self.CFG["max_speed"]  #
+        self.max_speed = self.CFG["max_speed"]  # Maximum speed adjustment

     def process(self, im0):
         """
@@ -65,6 +70,7 @@ class SpeedEstimator(BaseSolution):
             (SolutionResults): Contains processed image `plot_im` and `total_tracks` (number of tracked objects).

         Examples:
+            Process a frame for speed estimation
             >>> estimator = SpeedEstimator()
             >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
             >>> results = estimator.process(image)
@@ -89,15 +95,15 @@ class SpeedEstimator(BaseSolution):
                     p0, p1 = trk_hist[0], trk_hist[-1]  # First and last points of track
                     dt = (self.frame_count - self.trk_frame_ids[track_id]) / self.fps  # Time in seconds
                     if dt > 0:
-                        dx, dy = p1[0] - p0[0], p1[1] - p0[1]  #
-                        pixel_distance = sqrt(dx * dx + dy * dy)  #
-                        meters = pixel_distance * self.meter_per_pixel  #
+                        dx, dy = p1[0] - p0[0], p1[1] - p0[1]  # Pixel displacement
+                        pixel_distance = sqrt(dx * dx + dy * dy)  # Calculate pixel distance
+                        meters = pixel_distance * self.meter_per_pixel  # Convert to meters
                         self.spd[track_id] = int(
                             min((meters / dt) * 3.6, self.max_speed)
-                        )  #
-                        self.locked_ids.add(track_id)  #
-                        self.trk_hist.pop(track_id, None)  #
-                        self.trk_frame_ids.pop(track_id, None)  #
+                        )  # Convert to km/h and store final speed
+                        self.locked_ids.add(track_id)  # Prevent further updates
+                        self.trk_hist.pop(track_id, None)  # Free memory
+                        self.trk_frame_ids.pop(track_id, None)  # Remove frame start reference

             if track_id in self.spd:
                 speed_label = f"{self.spd[track_id]} km/h"
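The conversion above is straightforward kinematics: pixel displacement is scaled to meters via meter_per_pixel, divided by elapsed time, and multiplied by 3.6 to go from m/s to km/h. A minimal standalone sketch of the same arithmetic (the helper name and sample numbers are ours, not part of the package):

    from math import sqrt

    def pixel_track_to_kmh(p0, p1, frame_span, fps, meter_per_pixel, max_speed):
        """Turn pixel displacement between two track points into a capped km/h estimate."""
        dt = frame_span / fps  # elapsed time in seconds
        pixels = sqrt((p1[0] - p0[0]) ** 2 + (p1[1] - p0[1]) ** 2)  # straight-line pixel distance
        meters = pixels * meter_per_pixel  # scene-scale conversion
        return int(min((meters / dt) * 3.6, max_speed))  # m/s -> km/h, capped at max_speed

    # 50 px over 25 frames at 25 FPS with 0.04 m/px -> 2 m in 1 s -> 7.2 km/h -> 7
    print(pixel_track_to_kmh((0, 0), (30, 40), 25, 25, 0.04, 120))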
ultralytics/solutions/streamlit_inference.py
CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import io
-from typing import Any
+from typing import Any, List

 import cv2

@@ -24,7 +24,7 @@ class Inference:
         model_path (str): Path to the loaded model.
         model (YOLO): The YOLO model instance.
         source (str): Selected video source (webcam or video file).
-        enable_trk (
+        enable_trk (bool): Enable tracking option.
         conf (float): Confidence threshold for detection.
         iou (float): IoU threshold for non-maximum suppression.
         org_frame (Any): Container for the original frame to be displayed.
@@ -33,14 +33,19 @@ class Inference:
         selected_ind (List[int]): List of selected class indices for detection.

     Methods:
-        web_ui:
-        sidebar:
-        source_upload:
-        configure:
-        inference:
+        web_ui: Set up the Streamlit web interface with custom HTML elements.
+        sidebar: Configure the Streamlit sidebar for model and inference settings.
+        source_upload: Handle video file uploads through the Streamlit interface.
+        configure: Configure the model and load selected classes for inference.
+        inference: Perform real-time object detection inference.

     Examples:
-
+        Create an Inference instance with a custom model
+        >>> inf = Inference(model="path/to/model.pt")
+        >>> inf.inference()
+
+        Create an Inference instance with default settings
+        >>> inf = Inference()
         >>> inf.inference()
     """

@@ -62,7 +67,7 @@ class Inference:
         self.org_frame = None  # Container for the original frame display
         self.ann_frame = None  # Container for the annotated frame display
         self.vid_file_name = None  # Video file name or webcam index
-        self.selected_ind = []  # List of selected class indices for detection
+        self.selected_ind: List[int] = []  # List of selected class indices for detection
         self.model = None  # YOLO model instance

         self.temp_dict = {"model": None, **kwargs}
@@ -73,7 +78,7 @@ class Inference:
         LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")

     def web_ui(self):
-        """
+        """Set up the Streamlit web interface with custom HTML elements."""
         menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style

         # Main title of streamlit application
@@ -102,7 +107,7 @@ class Inference:
             "Video",
             ("webcam", "video"),
         )  # Add source selection dropdown
-        self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No"))  # Enable object tracking
+        self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes"  # Enable object tracking
         self.conf = float(
             self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
         )  # Slider for confidence
@@ -166,7 +171,7 @@ class Inference:
                 break

             # Process frame with model
-            if self.enable_trk
+            if self.enable_trk:
                 results = self.model.track(
                     frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
                 )
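The two streamlit_inference changes above work together: st.sidebar.radio returns the selected option string, so the attribute previously held "Yes" or "No", and a plain truthiness test would pass even when the user picked "No". Appending == "Yes" at assignment stores a real boolean. A tiny illustration of the difference (plain Python, no Streamlit required):

    choice = "No"  # what st.sidebar.radio("Enable Tracking", ("Yes", "No")) returns

    enable_trk = choice  # old behavior: any non-empty string is truthy
    assert bool(enable_trk) is True  # tracking would run even for "No"

    enable_trk = choice == "Yes"  # new behavior: a real bool matching the selection
    assert enable_trk is False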
ultralytics/solutions/trackzone.py
CHANGED
@@ -23,9 +23,9 @@ class TrackZone(BaseSolution):
         clss (List[int]): Class indices of tracked objects.

     Methods:
-        process:
-        extract_tracks:
-        display_output:
+        process: Process each frame of the video, applying region-based tracking.
+        extract_tracks: Extract tracking information from the input frame.
+        display_output: Display the processed output.

     Examples:
         >>> tracker = TrackZone()
@@ -82,7 +82,7 @@ class TrackZone(BaseSolution):
         )

         plot_im = annotator.result()
-        self.display_output(plot_im)  #
+        self.display_output(plot_im)  # Display output with base class function

         # Return a SolutionResults
         return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
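For context, TrackZone is typically driven along these lines (the video path and polygon are placeholders; process() returns the SolutionResults shown in the hunk above):

    import cv2
    from ultralytics import solutions

    trackzone = solutions.TrackZone(
        model="yolo11n.pt",
        region=[(150, 150), (1130, 150), (1130, 570), (150, 570)],  # pixel polygon to track inside
    )

    cap = cv2.VideoCapture("path/to/video.mp4")
    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            break
        results = trackzone.process(frame)  # SolutionResults with plot_im and total_tracks
    cap.release()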
ultralytics/trackers/basetrack.py
CHANGED
@@ -2,6 +2,7 @@
 """Module defines the base classes and structures for object tracking in YOLO."""

 from collections import OrderedDict
+from typing import Any

 import numpy as np

@@ -66,15 +67,7 @@ class BaseTrack:
     _count = 0

     def __init__(self):
-        """
-        Initialize a new track with a unique ID and foundational tracking attributes.
-
-        Examples:
-            Initialize a new track
-            >>> track = BaseTrack()
-            >>> print(track.track_id)
-            0
-        """
+        """Initialize a new track with a unique ID and foundational tracking attributes."""
         self.track_id = 0
         self.is_activated = False
         self.state = TrackState.New
@@ -88,37 +81,37 @@ class BaseTrack:
         self.location = (np.inf, np.inf)

     @property
-    def end_frame(self):
-        """
+    def end_frame(self) -> int:
+        """Return the ID of the most recent frame where the object was tracked."""
         return self.frame_id

     @staticmethod
-    def next_id():
+    def next_id() -> int:
         """Increment and return the next unique global track ID for object tracking."""
         BaseTrack._count += 1
         return BaseTrack._count

-    def activate(self, *args):
-        """
+    def activate(self, *args: Any) -> None:
+        """Activate the track with provided arguments, initializing necessary attributes for tracking."""
         raise NotImplementedError

-    def predict(self):
-        """
+    def predict(self) -> None:
+        """Predict the next state of the track based on the current state and tracking model."""
         raise NotImplementedError

-    def update(self, *args, **kwargs):
-        """
+    def update(self, *args: Any, **kwargs: Any) -> None:
+        """Update the track with new observations and data, modifying its state and attributes accordingly."""
         raise NotImplementedError

-    def mark_lost(self):
-        """
+    def mark_lost(self) -> None:
+        """Mark the track as lost by updating its state to TrackState.Lost."""
         self.state = TrackState.Lost

-    def mark_removed(self):
-        """
+    def mark_removed(self) -> None:
+        """Mark the track as removed by setting its state to TrackState.Removed."""
         self.state = TrackState.Removed

     @staticmethod
-    def reset_id():
+    def reset_id() -> None:
         """Reset the global track ID counter to its initial value."""
         BaseTrack._count = 0
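The trimmed docstrings and new annotations make the abstract surface of BaseTrack easier to read: activate, predict, and update are the hooks a concrete tracker must supply, while next_id/reset_id manage the shared ID counter. A toy subclass, just to show the contract (CentroidTrack is hypothetical, not part of the package):

    from ultralytics.trackers.basetrack import BaseTrack, TrackState

    class CentroidTrack(BaseTrack):
        """Minimal concrete track that stores only a point location."""

        def activate(self, xy, frame_id):
            self.track_id = self.next_id()  # unique global ID from the class counter
            self.location = xy
            self.frame_id = frame_id
            self.state = TrackState.Tracked
            self.is_activated = True

        def predict(self):
            pass  # no motion model in this toy example

        def update(self, xy, frame_id):
            self.location = xy
            self.frame_id = frame_id

    track = CentroidTrack()
    track.activate((10.0, 20.0), frame_id=1)
    print(track.end_frame)  # -> 1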
ultralytics/trackers/bot_sort.py
CHANGED
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from collections import deque
+from typing import Any, List, Optional

 import numpy as np
 import torch
@@ -51,7 +52,9 @@ class BOTrack(STrack):

     shared_kalman = KalmanFilterXYWH()

-    def __init__(
+    def __init__(
+        self, tlwh: np.ndarray, score: float, cls: int, feat: Optional[np.ndarray] = None, feat_history: int = 50
+    ):
         """
         Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.

@@ -59,7 +62,7 @@ class BOTrack(STrack):
             tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).
             score (float): Confidence score of the detection.
             cls (int): Class ID of the detected object.
-            feat (np.ndarray
+            feat (np.ndarray, optional): Feature vector associated with the detection.
             feat_history (int): Maximum length of the feature history deque.

         Examples:
@@ -79,7 +82,7 @@ class BOTrack(STrack):
         self.features = deque([], maxlen=feat_history)
         self.alpha = 0.9

-    def update_features(self, feat):
+    def update_features(self, feat: np.ndarray) -> None:
         """Update the feature vector and apply exponential moving average smoothing."""
         feat /= np.linalg.norm(feat)
         self.curr_feat = feat
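update_features is an exponential moving average over unit-normalized appearance embeddings, with alpha = 0.9 weighting the running estimate over the newest frame. The core of it, isolated as a sketch (the function name is ours):

    import numpy as np

    def ema_feature(smooth_feat, feat, alpha=0.9):
        """EMA over L2-normalized embeddings, mirroring BOTrack.update_features."""
        feat = feat / np.linalg.norm(feat)
        if smooth_feat is None:
            return feat  # first observation becomes the initial estimate
        smooth_feat = alpha * smooth_feat + (1 - alpha) * feat
        return smooth_feat / np.linalg.norm(smooth_feat)

    f = ema_feature(None, np.array([3.0, 4.0]))  # -> [0.6, 0.8]
    f = ema_feature(f, np.array([0.0, 1.0]))     # drifts only slightly toward the new feature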
@@ -90,7 +93,7 @@ class BOTrack(STrack):
         self.features.append(feat)
         self.smooth_feat /= np.linalg.norm(self.smooth_feat)

-    def predict(self):
+    def predict(self) -> None:
         """Predict the object's future state using the Kalman filter to update its mean and covariance."""
         mean_state = self.mean.copy()
         if self.state != TrackState.Tracked:
@@ -99,20 +102,20 @@ class BOTrack(STrack):

         self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

-    def re_activate(self, new_track, frame_id, new_id=False):
+    def re_activate(self, new_track: "BOTrack", frame_id: int, new_id: bool = False) -> None:
         """Reactivate a track with updated features and optionally assign a new ID."""
         if new_track.curr_feat is not None:
             self.update_features(new_track.curr_feat)
         super().re_activate(new_track, frame_id, new_id)

-    def update(self, new_track, frame_id):
+    def update(self, new_track: "BOTrack", frame_id: int) -> None:
         """Update the track with new detection information and the current frame ID."""
         if new_track.curr_feat is not None:
             self.update_features(new_track.curr_feat)
         super().update(new_track, frame_id)

     @property
-    def tlwh(self):
+    def tlwh(self) -> np.ndarray:
         """Return the current bounding box position in `(top left x, top left y, width, height)` format."""
         if self.mean is None:
             return self._tlwh.copy()
@@ -121,7 +124,7 @@ class BOTrack(STrack):
         return ret

     @staticmethod
-    def multi_predict(stracks):
+    def multi_predict(stracks: List["BOTrack"]) -> None:
         """Predict the mean and covariance for multiple object tracks using a shared Kalman filter."""
         if len(stracks) <= 0:
             return
@@ -136,12 +139,12 @@ class BOTrack(STrack):
             stracks[i].mean = mean
             stracks[i].covariance = cov

-    def convert_coords(self, tlwh):
+    def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
         """Convert tlwh bounding box coordinates to xywh format."""
         return self.tlwh_to_xywh(tlwh)

     @staticmethod
-    def tlwh_to_xywh(tlwh):
+    def tlwh_to_xywh(tlwh: np.ndarray) -> np.ndarray:
         """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
         ret = np.asarray(tlwh).copy()
         ret[:2] += ret[2:] / 2
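The coordinate helper is a corner-to-center shift; a quick numeric check of the `ret[:2] += ret[2:] / 2` line:

    import numpy as np

    tlwh = np.array([10.0, 20.0, 100.0, 50.0])  # top-left x/y, width, height
    xywh = tlwh.copy()
    xywh[:2] += xywh[2:] / 2  # move the reference point to the box center
    print(xywh)  # -> [60. 45. 100. 50.]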
@@ -176,12 +179,12 @@ class BOTSORT(BYTETracker):
     The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.
     """

-    def __init__(self, args, frame_rate=30):
+    def __init__(self, args: Any, frame_rate: int = 30):
         """
         Initialize BOTSORT object with ReID module and GMC algorithm.

         Args:
-            args (
+            args (Any): Parsed command-line arguments containing tracking parameters.
             frame_rate (int): Frame rate of the video being processed.

         Examples:
@@ -203,11 +206,13 @@ class BOTSORT(BYTETracker):
             else None
         )

-    def get_kalmanfilter(self):
+    def get_kalmanfilter(self) -> KalmanFilterXYWH:
         """Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
         return KalmanFilterXYWH()

-    def init_track(
+    def init_track(
+        self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None
+    ) -> List[BOTrack]:
         """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
         if len(dets) == 0:
             return []
@@ -217,7 +222,7 @@ class BOTSORT(BYTETracker):
         else:
             return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)]  # detections

-    def get_dists(self, tracks, detections):
+    def get_dists(self, tracks: List[BOTrack], detections: List[BOTrack]) -> np.ndarray:
         """Calculate distances between tracks and detections using IoU and optionally ReID embeddings."""
         dists = matching.iou_distance(tracks, detections)
         dists_mask = dists > (1 - self.proximity_thresh)
@@ -232,11 +237,11 @@ class BOTSORT(BYTETracker):
         dists = np.minimum(dists, emb_dists)
         return dists

-    def multi_predict(self, tracks):
+    def multi_predict(self, tracks: List[BOTrack]) -> None:
         """Predict the mean and covariance of multiple object tracks using a shared Kalman filter."""
         BOTrack.multi_predict(tracks)

-    def reset(self):
+    def reset(self) -> None:
         """Reset the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
         super().reset()
         self.gmc.reset_params()
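get_dists fuses motion and appearance: embedding distances are only trusted where the IoU distance says a track/detection pair is spatially plausible, and the smaller of the two distances wins per pair. A simplified sketch of that gating (thresholds shown are typical BoT-SORT defaults; the full method also applies extra scaling and optional score fusion not visible in these hunks):

    import numpy as np

    def fuse_dists(iou_dists, emb_dists, proximity_thresh=0.5, appearance_thresh=0.25):
        """Gate appearance by spatial proximity, then take the elementwise minimum."""
        spatial_gate = iou_dists > (1 - proximity_thresh)  # pairs too far apart to match
        emb = emb_dists.copy()
        emb[emb > (1 - appearance_thresh)] = 1.0  # reject weak appearance matches
        emb[spatial_gate] = 1.0  # the spatial gate overrides appearance
        return np.minimum(iou_dists, emb)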
@@ -245,14 +250,19 @@ class BOTSORT(BYTETracker):
 class ReID:
     """YOLO model as encoder for re-identification."""

-    def __init__(self, model):
-        """
+    def __init__(self, model: str):
+        """
+        Initialize encoder for re-identification.
+
+        Args:
+            model (str): Path to the YOLO model for re-identification.
+        """
         from ultralytics import YOLO

         self.model = YOLO(model)
         self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False)  # initialize

-    def __call__(self, img, dets):
+    def __call__(self, img: np.ndarray, dets: np.ndarray) -> List[np.ndarray]:
         """Extract embeddings for detected objects."""
         feats = self.model([save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))])
         if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]: