ultralytics 8.3.143__py3-none-any.whl → 8.3.144__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -157
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -7
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +184 -75
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
- ultralytics-8.3.144.dist-info/RECORD +272 -0
- ultralytics-8.3.143.dist-info/RECORD +0 -272
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/trackers/track.py
CHANGED
@@ -20,15 +20,11 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
|
|
20
20
|
Initialize trackers for object tracking during prediction.
|
21
21
|
|
22
22
|
Args:
|
23
|
-
predictor (
|
24
|
-
persist (bool): Whether to persist the trackers if they already exist.
|
25
|
-
|
26
|
-
Raises:
|
27
|
-
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
|
28
|
-
ValueError: If the task is 'classify' as classification doesn't support tracking.
|
23
|
+
predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.
|
24
|
+
persist (bool, optional): Whether to persist the trackers if they already exist.
|
29
25
|
|
30
26
|
Examples:
|
31
|
-
Initialize trackers for a predictor object
|
27
|
+
Initialize trackers for a predictor object
|
32
28
|
>>> predictor = SomePredictorClass()
|
33
29
|
>>> on_predict_start(predictor, persist=True)
|
34
30
|
"""
|
@@ -79,7 +75,7 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
|
|
79
75
|
|
80
76
|
Args:
|
81
77
|
predictor (object): The predictor object containing the predictions.
|
82
|
-
persist (bool): Whether to persist the trackers if they already exist.
|
78
|
+
persist (bool, optional): Whether to persist the trackers if they already exist.
|
83
79
|
|
84
80
|
Examples:
|
85
81
|
Postprocess predictions and update with tracking
|
@@ -1,6 +1,7 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
3
|
import copy
|
4
|
+
from typing import List, Optional
|
4
5
|
|
5
6
|
import cv2
|
6
7
|
import numpy as np
|
@@ -19,7 +20,7 @@ class GMC:
|
|
19
20
|
method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
20
21
|
downscale (int): Factor by which to downscale the frames for processing.
|
21
22
|
prevFrame (np.ndarray): Previous frame for tracking.
|
22
|
-
prevKeyPoints (
|
23
|
+
prevKeyPoints (List): Keypoints from the previous frame.
|
23
24
|
prevDescriptors (np.ndarray): Descriptors from the previous frame.
|
24
25
|
initializedFirstFrame (bool): Flag indicating if the first frame has been processed.
|
25
26
|
|
@@ -88,13 +89,13 @@ class GMC:
|
|
88
89
|
self.prevDescriptors = None
|
89
90
|
self.initializedFirstFrame = False
|
90
91
|
|
91
|
-
def apply(self, raw_frame: np.ndarray, detections:
|
92
|
+
def apply(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:
|
92
93
|
"""
|
93
94
|
Apply object detection on a raw frame using the specified method.
|
94
95
|
|
95
96
|
Args:
|
96
97
|
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
97
|
-
detections (List
|
98
|
+
detections (List, optional): List of detections to be used in the processing.
|
98
99
|
|
99
100
|
Returns:
|
100
101
|
(np.ndarray): Transformation matrix with shape (2, 3).
|
@@ -136,23 +137,18 @@ class GMC:
|
|
136
137
|
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
|
137
138
|
H = np.eye(2, 3, dtype=np.float32)
|
138
139
|
|
139
|
-
# Downscale image
|
140
|
+
# Downscale image for computational efficiency
|
140
141
|
if self.downscale > 1.0:
|
141
142
|
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
142
143
|
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
143
144
|
|
144
|
-
# Handle first frame
|
145
|
+
# Handle first frame initialization
|
145
146
|
if not self.initializedFirstFrame:
|
146
|
-
# Initialize data
|
147
147
|
self.prevFrame = frame.copy()
|
148
|
-
|
149
|
-
# Initialization done
|
150
148
|
self.initializedFirstFrame = True
|
151
|
-
|
152
149
|
return H
|
153
150
|
|
154
|
-
# Run the ECC algorithm
|
155
|
-
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
|
151
|
+
# Run the ECC algorithm to find transformation matrix
|
156
152
|
try:
|
157
153
|
(_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
|
158
154
|
except Exception as e:
|
@@ -160,13 +156,13 @@ class GMC:
|
|
160
156
|
|
161
157
|
return H
|
162
158
|
|
163
|
-
def apply_features(self, raw_frame: np.ndarray, detections:
|
159
|
+
def apply_features(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:
|
164
160
|
"""
|
165
161
|
Apply feature-based methods like ORB or SIFT to a raw frame.
|
166
162
|
|
167
163
|
Args:
|
168
164
|
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
169
|
-
detections (List
|
165
|
+
detections (List, optional): List of detections to be used in the processing.
|
170
166
|
|
171
167
|
Returns:
|
172
168
|
(np.ndarray): Transformation matrix with shape (2, 3).
|
@@ -182,55 +178,50 @@ class GMC:
|
|
182
178
|
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
|
183
179
|
H = np.eye(2, 3)
|
184
180
|
|
185
|
-
# Downscale image
|
181
|
+
# Downscale image for computational efficiency
|
186
182
|
if self.downscale > 1.0:
|
187
183
|
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
188
184
|
width = width // self.downscale
|
189
185
|
height = height // self.downscale
|
190
186
|
|
191
|
-
#
|
187
|
+
# Create mask for keypoint detection, excluding border regions
|
192
188
|
mask = np.zeros_like(frame)
|
193
189
|
mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255
|
190
|
+
|
191
|
+
# Exclude detection regions from mask to avoid tracking detected objects
|
194
192
|
if detections is not None:
|
195
193
|
for det in detections:
|
196
194
|
tlbr = (det[:4] / self.downscale).astype(np.int_)
|
197
195
|
mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0
|
198
196
|
|
197
|
+
# Find keypoints and compute descriptors
|
199
198
|
keypoints = self.detector.detect(frame, mask)
|
200
|
-
|
201
|
-
# Compute the descriptors
|
202
199
|
keypoints, descriptors = self.extractor.compute(frame, keypoints)
|
203
200
|
|
204
|
-
# Handle first frame
|
201
|
+
# Handle first frame initialization
|
205
202
|
if not self.initializedFirstFrame:
|
206
|
-
# Initialize data
|
207
203
|
self.prevFrame = frame.copy()
|
208
204
|
self.prevKeyPoints = copy.copy(keypoints)
|
209
205
|
self.prevDescriptors = copy.copy(descriptors)
|
210
|
-
|
211
|
-
# Initialization done
|
212
206
|
self.initializedFirstFrame = True
|
213
|
-
|
214
207
|
return H
|
215
208
|
|
216
|
-
# Match descriptors
|
209
|
+
# Match descriptors between previous and current frame
|
217
210
|
knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
|
218
211
|
|
219
|
-
# Filter matches based on
|
212
|
+
# Filter matches based on spatial distance constraints
|
220
213
|
matches = []
|
221
214
|
spatialDistances = []
|
222
|
-
|
223
215
|
maxSpatialDistance = 0.25 * np.array([width, height])
|
224
216
|
|
225
217
|
# Handle empty matches case
|
226
218
|
if len(knnMatches) == 0:
|
227
|
-
# Store to next iteration
|
228
219
|
self.prevFrame = frame.copy()
|
229
220
|
self.prevKeyPoints = copy.copy(keypoints)
|
230
221
|
self.prevDescriptors = copy.copy(descriptors)
|
231
|
-
|
232
222
|
return H
|
233
223
|
|
224
|
+
# Apply Lowe's ratio test and spatial distance filtering
|
234
225
|
for m, n in knnMatches:
|
235
226
|
if m.distance < 0.9 * n.distance:
|
236
227
|
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
|
@@ -247,11 +238,12 @@ class GMC:
|
|
247
238
|
spatialDistances.append(spatialDistance)
|
248
239
|
matches.append(m)
|
249
240
|
|
241
|
+
# Filter outliers using statistical analysis
|
250
242
|
meanSpatialDistances = np.mean(spatialDistances, 0)
|
251
243
|
stdSpatialDistances = np.std(spatialDistances, 0)
|
252
|
-
|
253
244
|
inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances
|
254
245
|
|
246
|
+
# Extract good matches and corresponding points
|
255
247
|
goodMatches = []
|
256
248
|
prevPoints = []
|
257
249
|
currPoints = []
|
@@ -264,39 +256,18 @@ class GMC:
|
|
264
256
|
prevPoints = np.array(prevPoints)
|
265
257
|
currPoints = np.array(currPoints)
|
266
258
|
|
267
|
-
#
|
268
|
-
# if False:
|
269
|
-
# import matplotlib.pyplot as plt
|
270
|
-
# matches_img = np.hstack((self.prevFrame, frame))
|
271
|
-
# matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
|
272
|
-
# W = self.prevFrame.shape[1]
|
273
|
-
# for m in goodMatches:
|
274
|
-
# prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
|
275
|
-
# curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
|
276
|
-
# curr_pt[0] += W
|
277
|
-
# color = np.random.randint(0, 255, 3)
|
278
|
-
# color = (int(color[0]), int(color[1]), int(color[2]))
|
279
|
-
#
|
280
|
-
# matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA)
|
281
|
-
# matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1)
|
282
|
-
# matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1)
|
283
|
-
#
|
284
|
-
# plt.figure()
|
285
|
-
# plt.imshow(matches_img)
|
286
|
-
# plt.show()
|
287
|
-
|
288
|
-
# Find rigid matrix
|
259
|
+
# Estimate transformation matrix using RANSAC
|
289
260
|
if prevPoints.shape[0] > 4:
|
290
261
|
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
|
291
262
|
|
292
|
-
#
|
263
|
+
# Scale translation components back to original resolution
|
293
264
|
if self.downscale > 1.0:
|
294
265
|
H[0, 2] *= self.downscale
|
295
266
|
H[1, 2] *= self.downscale
|
296
267
|
else:
|
297
268
|
LOGGER.warning("not enough matching points")
|
298
269
|
|
299
|
-
# Store
|
270
|
+
# Store current frame data for next iteration
|
300
271
|
self.prevFrame = frame.copy()
|
301
272
|
self.prevKeyPoints = copy.copy(keypoints)
|
302
273
|
self.prevDescriptors = copy.copy(descriptors)
|
@@ -324,24 +295,24 @@ class GMC:
|
|
324
295
|
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
|
325
296
|
H = np.eye(2, 3)
|
326
297
|
|
327
|
-
# Downscale image
|
298
|
+
# Downscale image for computational efficiency
|
328
299
|
if self.downscale > 1.0:
|
329
300
|
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
330
301
|
|
331
|
-
# Find
|
302
|
+
# Find good features to track
|
332
303
|
keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)
|
333
304
|
|
334
|
-
# Handle first frame
|
305
|
+
# Handle first frame initialization
|
335
306
|
if not self.initializedFirstFrame or self.prevKeyPoints is None:
|
336
307
|
self.prevFrame = frame.copy()
|
337
308
|
self.prevKeyPoints = copy.copy(keypoints)
|
338
309
|
self.initializedFirstFrame = True
|
339
310
|
return H
|
340
311
|
|
341
|
-
#
|
312
|
+
# Calculate optical flow using Lucas-Kanade method
|
342
313
|
matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
|
343
314
|
|
344
|
-
#
|
315
|
+
# Extract successfully tracked points
|
345
316
|
prevPoints = []
|
346
317
|
currPoints = []
|
347
318
|
|
@@ -353,16 +324,18 @@ class GMC:
|
|
353
324
|
prevPoints = np.array(prevPoints)
|
354
325
|
currPoints = np.array(currPoints)
|
355
326
|
|
356
|
-
#
|
327
|
+
# Estimate transformation matrix using RANSAC
|
357
328
|
if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == currPoints.shape[0]):
|
358
329
|
H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
|
359
330
|
|
331
|
+
# Scale translation components back to original resolution
|
360
332
|
if self.downscale > 1.0:
|
361
333
|
H[0, 2] *= self.downscale
|
362
334
|
H[1, 2] *= self.downscale
|
363
335
|
else:
|
364
336
|
LOGGER.warning("not enough matching points")
|
365
337
|
|
338
|
+
# Store current frame data for next iteration
|
366
339
|
self.prevFrame = frame.copy()
|
367
340
|
self.prevKeyPoints = copy.copy(keypoints)
|
368
341
|
|
@@ -20,12 +20,12 @@ class KalmanFilterXYAH:
|
|
20
20
|
_std_weight_velocity (float): Standard deviation weight for velocity.
|
21
21
|
|
22
22
|
Methods:
|
23
|
-
initiate:
|
24
|
-
predict:
|
25
|
-
project:
|
26
|
-
multi_predict:
|
27
|
-
update:
|
28
|
-
gating_distance:
|
23
|
+
initiate: Create a track from an unassociated measurement.
|
24
|
+
predict: Run the Kalman filter prediction step.
|
25
|
+
project: Project the state distribution to measurement space.
|
26
|
+
multi_predict: Run the Kalman filter prediction step (vectorized version).
|
27
|
+
update: Run the Kalman filter correction step.
|
28
|
+
gating_distance: Compute the gating distance between state distribution and measurements.
|
29
29
|
|
30
30
|
Examples:
|
31
31
|
Initialize the Kalman filter and create a track from a measurement
|
@@ -70,8 +70,8 @@ class KalmanFilterXYAH:
|
|
70
70
|
and height h.
|
71
71
|
|
72
72
|
Returns:
|
73
|
-
(np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
|
74
|
-
(np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
|
73
|
+
mean (np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
|
74
|
+
covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
|
75
75
|
|
76
76
|
Examples:
|
77
77
|
>>> kf = KalmanFilterXYAH()
|
@@ -104,8 +104,8 @@ class KalmanFilterXYAH:
|
|
104
104
|
covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
|
105
105
|
|
106
106
|
Returns:
|
107
|
-
(np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
|
108
|
-
(np.ndarray): Covariance matrix of the predicted state.
|
107
|
+
mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
|
108
|
+
covariance (np.ndarray): Covariance matrix of the predicted state.
|
109
109
|
|
110
110
|
Examples:
|
111
111
|
>>> kf = KalmanFilterXYAH()
|
@@ -141,8 +141,8 @@ class KalmanFilterXYAH:
|
|
141
141
|
covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
|
142
142
|
|
143
143
|
Returns:
|
144
|
-
(np.ndarray): Projected mean of the given state estimate.
|
145
|
-
(np.ndarray): Projected covariance matrix of the given state estimate.
|
144
|
+
mean (np.ndarray): Projected mean of the given state estimate.
|
145
|
+
covariance (np.ndarray): Projected covariance matrix of the given state estimate.
|
146
146
|
|
147
147
|
Examples:
|
148
148
|
>>> kf = KalmanFilterXYAH()
|
@@ -171,8 +171,8 @@ class KalmanFilterXYAH:
|
|
171
171
|
covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
|
172
172
|
|
173
173
|
Returns:
|
174
|
-
(np.ndarray): Mean matrix of the predicted states with shape (N, 8).
|
175
|
-
(np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
|
174
|
+
mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
|
175
|
+
covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
|
176
176
|
|
177
177
|
Examples:
|
178
178
|
>>> mean = np.random.rand(10, 8) # 10 object states
|
@@ -213,8 +213,8 @@ class KalmanFilterXYAH:
|
|
213
213
|
position, a the aspect ratio, and h the height of the bounding box.
|
214
214
|
|
215
215
|
Returns:
|
216
|
-
(np.ndarray): Measurement-corrected state mean.
|
217
|
-
(np.ndarray): Measurement-corrected state covariance.
|
216
|
+
new_mean (np.ndarray): Measurement-corrected state mean.
|
217
|
+
new_covariance (np.ndarray): Measurement-corrected state covariance.
|
218
218
|
|
219
219
|
Examples:
|
220
220
|
>>> kf = KalmanFilterXYAH()
|
@@ -254,8 +254,8 @@ class KalmanFilterXYAH:
|
|
254
254
|
covariance (np.ndarray): Covariance of the state distribution (8x8 dimensional).
|
255
255
|
measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the
|
256
256
|
bounding box center position, a the aspect ratio, and h the height.
|
257
|
-
only_position (bool): If True, distance computation is done with respect to box center position only.
|
258
|
-
metric (str): The metric to use for calculating the distance. Options are 'gaussian' for the squared
|
257
|
+
only_position (bool, optional): If True, distance computation is done with respect to box center position only.
|
258
|
+
metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the squared
|
259
259
|
Euclidean distance and 'maha' for the squared Mahalanobis distance.
|
260
260
|
|
261
261
|
Returns:
|
@@ -302,11 +302,11 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|
302
302
|
_std_weight_velocity (float): Standard deviation weight for velocity.
|
303
303
|
|
304
304
|
Methods:
|
305
|
-
initiate:
|
306
|
-
predict:
|
307
|
-
project:
|
308
|
-
multi_predict:
|
309
|
-
update:
|
305
|
+
initiate: Create a track from an unassociated measurement.
|
306
|
+
predict: Run the Kalman filter prediction step.
|
307
|
+
project: Project the state distribution to measurement space.
|
308
|
+
multi_predict: Run the Kalman filter prediction step in a vectorized manner.
|
309
|
+
update: Run the Kalman filter correction step.
|
310
310
|
|
311
311
|
Examples:
|
312
312
|
Create a Kalman filter and initialize a track
|
@@ -325,8 +325,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|
325
325
|
measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.
|
326
326
|
|
327
327
|
Returns:
|
328
|
-
(np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
|
329
|
-
(np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
|
328
|
+
mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
|
329
|
+
covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
|
330
330
|
|
331
331
|
Examples:
|
332
332
|
>>> kf = KalmanFilterXYWH()
|
@@ -361,7 +361,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|
361
361
|
covariance = np.diag(np.square(std))
|
362
362
|
return mean, covariance
|
363
363
|
|
364
|
-
def predict(self, mean, covariance):
|
364
|
+
def predict(self, mean: np.ndarray, covariance: np.ndarray):
|
365
365
|
"""
|
366
366
|
Run Kalman filter prediction step.
|
367
367
|
|
@@ -370,8 +370,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|
370
370
|
covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
|
371
371
|
|
372
372
|
Returns:
|
373
|
-
(np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
|
374
|
-
(np.ndarray): Covariance matrix of the predicted state.
|
373
|
+
mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
|
374
|
+
covariance (np.ndarray): Covariance matrix of the predicted state.
|
375
375
|
|
376
376
|
Examples:
|
377
377
|
>>> kf = KalmanFilterXYWH()
|
@@ -398,7 +398,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|
398
398
|
|
399
399
|
return mean, covariance
|
400
400
|
|
401
|
-
def project(self, mean, covariance):
|
401
|
+
def project(self, mean: np.ndarray, covariance: np.ndarray):
|
402
402
|
"""
|
403
403
|
Project state distribution to measurement space.
|
404
404
|
|
@@ -407,8 +407,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|
407
407
|
covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
|
408
408
|
|
409
409
|
Returns:
|
410
|
-
(np.ndarray): Projected mean of the given state estimate.
|
411
|
-
(np.ndarray): Projected covariance matrix of the given state estimate.
|
410
|
+
mean (np.ndarray): Projected mean of the given state estimate.
|
411
|
+
covariance (np.ndarray): Projected covariance matrix of the given state estimate.
|
412
412
|
|
413
413
|
Examples:
|
414
414
|
>>> kf = KalmanFilterXYWH()
|
@@ -428,7 +428,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|
428
428
|
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
|
429
429
|
return mean, covariance + innovation_cov
|
430
430
|
|
431
|
-
def multi_predict(self, mean, covariance):
|
431
|
+
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
|
432
432
|
"""
|
433
433
|
Run Kalman filter prediction step (Vectorized version).
|
434
434
|
|
@@ -437,8 +437,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|
437
437
|
covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
|
438
438
|
|
439
439
|
Returns:
|
440
|
-
(np.ndarray): Mean matrix of the predicted states with shape (N, 8).
|
441
|
-
(np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
|
440
|
+
mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
|
441
|
+
covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
|
442
442
|
|
443
443
|
Examples:
|
444
444
|
>>> mean = np.random.rand(5, 8) # 5 objects with 8-dimensional state vectors
|
@@ -469,7 +469,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|
469
469
|
|
470
470
|
return mean, covariance
|
471
471
|
|
472
|
-
def update(self, mean, covariance, measurement):
|
472
|
+
def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
|
473
473
|
"""
|
474
474
|
Run Kalman filter correction step.
|
475
475
|
|
@@ -480,8 +480,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|
480
480
|
position, w the width, and h the height of the bounding box.
|
481
481
|
|
482
482
|
Returns:
|
483
|
-
(np.ndarray): Measurement-corrected state mean.
|
484
|
-
(np.ndarray): Measurement-corrected state covariance.
|
483
|
+
new_mean (np.ndarray): Measurement-corrected state mean.
|
484
|
+
new_covariance (np.ndarray): Measurement-corrected state covariance.
|
485
485
|
|
486
486
|
Examples:
|
487
487
|
>>> kf = KalmanFilterXYWH()
|
@@ -17,7 +17,7 @@ except (ImportError, AssertionError, AttributeError):
|
|
17
17
|
import lap
|
18
18
|
|
19
19
|
|
20
|
-
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True)
|
20
|
+
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True):
|
21
21
|
"""
|
22
22
|
Perform linear assignment using either the scipy or lap.lapjv method.
|
23
23
|
|