ultralytics 8.3.143__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +11 -11
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +39 -39
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +187 -157
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +1 -1
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +1 -1
  93. ultralytics/solutions/instance_segmentation.py +6 -3
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +15 -7
  96. ultralytics/solutions/object_cropper.py +3 -2
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +184 -75
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +42 -28
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
  143. ultralytics-8.3.144.dist-info/RECORD +272 -0
  144. ultralytics-8.3.143.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/trackers/track.py
@@ -20,15 +20,11 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
     Initialize trackers for object tracking during prediction.
 
     Args:
-        predictor (object): The predictor object to initialize trackers for.
-        persist (bool): Whether to persist the trackers if they already exist.
-
-    Raises:
-        AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
-        ValueError: If the task is 'classify' as classification doesn't support tracking.
+        predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.
+        persist (bool, optional): Whether to persist the trackers if they already exist.
 
     Examples:
-        Initialize trackers for a predictor object:
+        Initialize trackers for a predictor object
         >>> predictor = SomePredictorClass()
         >>> on_predict_start(predictor, persist=True)
     """
@@ -79,7 +75,7 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
 
     Args:
         predictor (object): The predictor object containing the predictions.
-        persist (bool): Whether to persist the trackers if they already exist.
+        persist (bool, optional): Whether to persist the trackers if they already exist.
 
     Examples:
         Postprocess predictions and update with tracking
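The two hunks above are from ultralytics/trackers/track.py and only refine the tracker-callback docstrings. For context, here is a minimal sketch of how these callbacks get exercised through the public tracking API; the weights file and video path below are placeholders, not values taken from the diff:

# Minimal sketch: model.track() wires up on_predict_start / on_predict_postprocess_end
# internally, so users normally never call the callbacks directly.
from ultralytics import YOLO

model = YOLO("yolo11n.pt")  # placeholder weights file
# persist=True keeps existing trackers alive across repeated track() calls on the same stream
results = model.track("path/to/video.mp4", tracker="bytetrack.yaml", persist=True)
for r in results:
    print(r.boxes.id)  # per-object track IDs assigned by the tracker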
ultralytics/trackers/utils/gmc.py
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import copy
+from typing import List, Optional
 
 import cv2
 import numpy as np
@@ -19,7 +20,7 @@ class GMC:
         method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
         downscale (int): Factor by which to downscale the frames for processing.
         prevFrame (np.ndarray): Previous frame for tracking.
-        prevKeyPoints (list): Keypoints from the previous frame.
+        prevKeyPoints (List): Keypoints from the previous frame.
         prevDescriptors (np.ndarray): Descriptors from the previous frame.
         initializedFirstFrame (bool): Flag indicating if the first frame has been processed.
 
@@ -88,13 +89,13 @@ class GMC:
         self.prevDescriptors = None
         self.initializedFirstFrame = False
 
-    def apply(self, raw_frame: np.ndarray, detections: list = None) -> np.ndarray:
+    def apply(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:
         """
         Apply object detection on a raw frame using the specified method.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
-            detections (List | None): List of detections to be used in the processing.
+            detections (List, optional): List of detections to be used in the processing.
 
         Returns:
             (np.ndarray): Transformation matrix with shape (2, 3).
@@ -136,23 +137,18 @@ class GMC:
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
         H = np.eye(2, 3, dtype=np.float32)
 
-        # Downscale image
+        # Downscale image for computational efficiency
         if self.downscale > 1.0:
             frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
             frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
 
-        # Handle first frame
+        # Handle first frame initialization
         if not self.initializedFirstFrame:
-            # Initialize data
             self.prevFrame = frame.copy()
-
-            # Initialization done
             self.initializedFirstFrame = True
-
             return H
 
-        # Run the ECC algorithm. The results are stored in warp_matrix.
-        # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
+        # Run the ECC algorithm to find transformation matrix
         try:
             (_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
         except Exception as e:
@@ -160,13 +156,13 @@
 
         return H
 
-    def apply_features(self, raw_frame: np.ndarray, detections: list = None) -> np.ndarray:
+    def apply_features(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:
         """
         Apply feature-based methods like ORB or SIFT to a raw frame.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
-            detections (List | None): List of detections to be used in the processing.
+            detections (List, optional): List of detections to be used in the processing.
 
         Returns:
             (np.ndarray): Transformation matrix with shape (2, 3).
@@ -182,55 +178,50 @@ class GMC:
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
         H = np.eye(2, 3)
 
-        # Downscale image
+        # Downscale image for computational efficiency
         if self.downscale > 1.0:
             frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
             width = width // self.downscale
             height = height // self.downscale
 
-        # Find the keypoints
+        # Create mask for keypoint detection, excluding border regions
         mask = np.zeros_like(frame)
         mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255
+
+        # Exclude detection regions from mask to avoid tracking detected objects
         if detections is not None:
             for det in detections:
                 tlbr = (det[:4] / self.downscale).astype(np.int_)
                 mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0
 
+        # Find keypoints and compute descriptors
         keypoints = self.detector.detect(frame, mask)
-
-        # Compute the descriptors
         keypoints, descriptors = self.extractor.compute(frame, keypoints)
 
-        # Handle first frame
+        # Handle first frame initialization
         if not self.initializedFirstFrame:
-            # Initialize data
             self.prevFrame = frame.copy()
             self.prevKeyPoints = copy.copy(keypoints)
             self.prevDescriptors = copy.copy(descriptors)
-
-            # Initialization done
             self.initializedFirstFrame = True
-
             return H
 
-        # Match descriptors
+        # Match descriptors between previous and current frame
         knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
 
-        # Filter matches based on smallest spatial distance
+        # Filter matches based on spatial distance constraints
        matches = []
        spatialDistances = []
-
        maxSpatialDistance = 0.25 * np.array([width, height])
 
        # Handle empty matches case
        if len(knnMatches) == 0:
-            # Store to next iteration
            self.prevFrame = frame.copy()
            self.prevKeyPoints = copy.copy(keypoints)
            self.prevDescriptors = copy.copy(descriptors)
-
            return H
 
+        # Apply Lowe's ratio test and spatial distance filtering
        for m, n in knnMatches:
            if m.distance < 0.9 * n.distance:
                prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
@@ -247,11 +238,12 @@ class GMC:
                    spatialDistances.append(spatialDistance)
                    matches.append(m)
 
+        # Filter outliers using statistical analysis
        meanSpatialDistances = np.mean(spatialDistances, 0)
        stdSpatialDistances = np.std(spatialDistances, 0)
-
        inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances
 
+        # Extract good matches and corresponding points
        goodMatches = []
        prevPoints = []
        currPoints = []
@@ -264,39 +256,18 @@ class GMC:
         prevPoints = np.array(prevPoints)
         currPoints = np.array(currPoints)
 
-        # Draw the keypoint matches on the output image
-        # if False:
-        #     import matplotlib.pyplot as plt
-        #     matches_img = np.hstack((self.prevFrame, frame))
-        #     matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
-        #     W = self.prevFrame.shape[1]
-        #     for m in goodMatches:
-        #         prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
-        #         curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
-        #         curr_pt[0] += W
-        #         color = np.random.randint(0, 255, 3)
-        #         color = (int(color[0]), int(color[1]), int(color[2]))
-        #
-        #         matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA)
-        #         matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1)
-        #         matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1)
-        #
-        #     plt.figure()
-        #     plt.imshow(matches_img)
-        #     plt.show()
-
-        # Find rigid matrix
+        # Estimate transformation matrix using RANSAC
         if prevPoints.shape[0] > 4:
             H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
 
-            # Handle downscale
+            # Scale translation components back to original resolution
             if self.downscale > 1.0:
                 H[0, 2] *= self.downscale
                 H[1, 2] *= self.downscale
         else:
             LOGGER.warning("not enough matching points")
 
-        # Store to next iteration
+        # Store current frame data for next iteration
         self.prevFrame = frame.copy()
         self.prevKeyPoints = copy.copy(keypoints)
         self.prevDescriptors = copy.copy(descriptors)
@@ -324,24 +295,24 @@ class GMC:
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
         H = np.eye(2, 3)
 
-        # Downscale image
+        # Downscale image for computational efficiency
         if self.downscale > 1.0:
             frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
 
-        # Find the keypoints
+        # Find good features to track
         keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)
 
-        # Handle first frame
+        # Handle first frame initialization
         if not self.initializedFirstFrame or self.prevKeyPoints is None:
             self.prevFrame = frame.copy()
             self.prevKeyPoints = copy.copy(keypoints)
             self.initializedFirstFrame = True
             return H
 
-        # Find correspondences
+        # Calculate optical flow using Lucas-Kanade method
         matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
 
-        # Leave good correspondences only
+        # Extract successfully tracked points
         prevPoints = []
         currPoints = []
 
@@ -353,16 +324,18 @@
         prevPoints = np.array(prevPoints)
         currPoints = np.array(currPoints)
 
-        # Find rigid matrix
+        # Estimate transformation matrix using RANSAC
         if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == currPoints.shape[0]):
             H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
 
+            # Scale translation components back to original resolution
             if self.downscale > 1.0:
                 H[0, 2] *= self.downscale
                 H[1, 2] *= self.downscale
         else:
             LOGGER.warning("not enough matching points")
 
+        # Store current frame data for next iteration
         self.prevFrame = frame.copy()
         self.prevKeyPoints = copy.copy(keypoints)
 
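The hunks above type the optional detections argument of GMC.apply and GMC.apply_features as Optional[List] and rewrite inline comments; the motion-compensation logic itself is unchanged. A minimal sketch of driving the class frame by frame, assuming a synthetic static frame in place of real video:

import numpy as np

from ultralytics.trackers.utils.gmc import GMC

gmc = GMC(method="sparseOptFlow", downscale=2)
frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # stand-in for a video frame
for _ in range(3):
    # Feeding the same frame simulates a static camera; detections is optional and defaults to None
    H = gmc.apply(frame)
    print(H.shape)  # (2, 3) affine transform relative to the previous frame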
ultralytics/trackers/utils/kalman_filter.py
@@ -20,12 +20,12 @@ class KalmanFilterXYAH:
         _std_weight_velocity (float): Standard deviation weight for velocity.
 
     Methods:
-        initiate: Creates a track from an unassociated measurement.
-        predict: Runs the Kalman filter prediction step.
-        project: Projects the state distribution to measurement space.
-        multi_predict: Runs the Kalman filter prediction step (vectorized version).
-        update: Runs the Kalman filter correction step.
-        gating_distance: Computes the gating distance between state distribution and measurements.
+        initiate: Create a track from an unassociated measurement.
+        predict: Run the Kalman filter prediction step.
+        project: Project the state distribution to measurement space.
+        multi_predict: Run the Kalman filter prediction step (vectorized version).
+        update: Run the Kalman filter correction step.
+        gating_distance: Compute the gating distance between state distribution and measurements.
 
     Examples:
         Initialize the Kalman filter and create a track from a measurement
@@ -70,8 +70,8 @@ class KalmanFilterXYAH:
                 and height h.
 
         Returns:
-            (np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
-            (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
+            mean (np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
+            covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
 
         Examples:
             >>> kf = KalmanFilterXYAH()
@@ -104,8 +104,8 @@ class KalmanFilterXYAH:
             covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
 
         Returns:
-            (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
-            (np.ndarray): Covariance matrix of the predicted state.
+            mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
+            covariance (np.ndarray): Covariance matrix of the predicted state.
 
         Examples:
             >>> kf = KalmanFilterXYAH()
@@ -141,8 +141,8 @@ class KalmanFilterXYAH:
             covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
 
         Returns:
-            (np.ndarray): Projected mean of the given state estimate.
-            (np.ndarray): Projected covariance matrix of the given state estimate.
+            mean (np.ndarray): Projected mean of the given state estimate.
+            covariance (np.ndarray): Projected covariance matrix of the given state estimate.
 
         Examples:
             >>> kf = KalmanFilterXYAH()
@@ -171,8 +171,8 @@ class KalmanFilterXYAH:
             covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
 
         Returns:
-            (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
-            (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
+            mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
+            covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
 
         Examples:
             >>> mean = np.random.rand(10, 8)  # 10 object states
@@ -213,8 +213,8 @@ class KalmanFilterXYAH:
                 position, a the aspect ratio, and h the height of the bounding box.
 
         Returns:
-            (np.ndarray): Measurement-corrected state mean.
-            (np.ndarray): Measurement-corrected state covariance.
+            new_mean (np.ndarray): Measurement-corrected state mean.
+            new_covariance (np.ndarray): Measurement-corrected state covariance.
 
         Examples:
             >>> kf = KalmanFilterXYAH()
@@ -254,8 +254,8 @@ class KalmanFilterXYAH:
             covariance (np.ndarray): Covariance of the state distribution (8x8 dimensional).
             measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the
                 bounding box center position, a the aspect ratio, and h the height.
-            only_position (bool): If True, distance computation is done with respect to box center position only.
-            metric (str): The metric to use for calculating the distance. Options are 'gaussian' for the squared
+            only_position (bool, optional): If True, distance computation is done with respect to box center position only.
+            metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the squared
                 Euclidean distance and 'maha' for the squared Mahalanobis distance.
 
         Returns:
@@ -302,11 +302,11 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
         _std_weight_velocity (float): Standard deviation weight for velocity.
 
     Methods:
-        initiate: Creates a track from an unassociated measurement.
-        predict: Runs the Kalman filter prediction step.
-        project: Projects the state distribution to measurement space.
-        multi_predict: Runs the Kalman filter prediction step in a vectorized manner.
-        update: Runs the Kalman filter correction step.
+        initiate: Create a track from an unassociated measurement.
+        predict: Run the Kalman filter prediction step.
+        project: Project the state distribution to measurement space.
+        multi_predict: Run the Kalman filter prediction step in a vectorized manner.
+        update: Run the Kalman filter correction step.
 
     Examples:
         Create a Kalman filter and initialize a track
@@ -325,8 +325,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
             measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.
 
         Returns:
-            (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
-            (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
+            mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
+            covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
 
         Examples:
             >>> kf = KalmanFilterXYWH()
@@ -361,7 +361,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
         covariance = np.diag(np.square(std))
         return mean, covariance
 
-    def predict(self, mean, covariance):
+    def predict(self, mean: np.ndarray, covariance: np.ndarray):
         """
         Run Kalman filter prediction step.
 
@@ -370,8 +370,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
             covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
 
         Returns:
-            (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
-            (np.ndarray): Covariance matrix of the predicted state.
+            mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
+            covariance (np.ndarray): Covariance matrix of the predicted state.
 
         Examples:
             >>> kf = KalmanFilterXYWH()
@@ -398,7 +398,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
 
         return mean, covariance
 
-    def project(self, mean, covariance):
+    def project(self, mean: np.ndarray, covariance: np.ndarray):
         """
         Project state distribution to measurement space.
 
@@ -407,8 +407,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
             covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
 
         Returns:
-            (np.ndarray): Projected mean of the given state estimate.
-            (np.ndarray): Projected covariance matrix of the given state estimate.
+            mean (np.ndarray): Projected mean of the given state estimate.
+            covariance (np.ndarray): Projected covariance matrix of the given state estimate.
 
         Examples:
             >>> kf = KalmanFilterXYWH()
@@ -428,7 +428,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
         covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
         return mean, covariance + innovation_cov
 
-    def multi_predict(self, mean, covariance):
+    def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
         """
         Run Kalman filter prediction step (Vectorized version).
 
@@ -437,8 +437,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
             covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
 
         Returns:
-            (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
-            (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
+            mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
+            covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
 
         Examples:
             >>> mean = np.random.rand(5, 8)  # 5 objects with 8-dimensional state vectors
@@ -469,7 +469,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
 
         return mean, covariance
 
-    def update(self, mean, covariance, measurement):
+    def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
         """
         Run Kalman filter correction step.
 
@@ -480,8 +480,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
                 position, w the width, and h the height of the bounding box.
 
         Returns:
-            (np.ndarray): Measurement-corrected state mean.
-            (np.ndarray): Measurement-corrected state covariance.
+            new_mean (np.ndarray): Measurement-corrected state mean.
+            new_covariance (np.ndarray): Measurement-corrected state covariance.
 
         Examples:
             >>> kf = KalmanFilterXYWH()
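The kalman_filter.py hunks above rename docstring return values and add np.ndarray type hints; the filter math is untouched. As a reading aid, a short sketch of the initiate/predict/update cycle those docstrings describe, with a made-up (x, y, a, h) measurement:

import numpy as np

from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH

kf = KalmanFilterXYAH()
measurement = np.array([100.0, 200.0, 0.5, 80.0])  # (x, y, a, h): center, aspect ratio, height
mean, covariance = kf.initiate(measurement)  # 8-dimensional state mean and 8x8 covariance
mean, covariance = kf.predict(mean, covariance)  # propagate the state one time step
mean, covariance = kf.update(mean, covariance, measurement)  # correct with a new measurement
print(mean.shape, covariance.shape)  # (8,), (8, 8)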
ultralytics/trackers/utils/matching.py
@@ -17,7 +17,7 @@ except (ImportError, AssertionError, AttributeError):
     import lap
 
 
-def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
+def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True):
     """
     Perform linear assignment using either the scipy or lap.lapjv method.
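The final hunk only drops the explicit -> tuple return annotation on linear_assignment. A toy sketch of calling the helper, assuming its usual (matches, unmatched_a, unmatched_b) return shape; the cost values below are arbitrary:

import numpy as np

from ultralytics.trackers.utils.matching import linear_assignment

cost = np.array([[0.1, 0.9], [0.8, 0.2], [0.7, 0.6]])  # rows: existing tracks, cols: new detections
matches, unmatched_tracks, unmatched_dets = linear_assignment(cost, thresh=0.5)
print(matches)  # matched (row, column) index pairs whose cost is below thresh
print(unmatched_tracks)  # row indices left without a match
print(unmatched_dets)  # column indices left without a match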