dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl

This diff shows the content changes between two package versions publicly released to a supported registry. The information is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (158)
  1. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
  3. tests/conftest.py +5 -8
  4. tests/test_cli.py +1 -8
  5. tests/test_python.py +1 -2
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +34 -49
  8. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  9. ultralytics/cfg/datasets/kitti.yaml +27 -0
  10. ultralytics/cfg/datasets/lvis.yaml +5 -5
  11. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  12. ultralytics/data/annotator.py +3 -4
  13. ultralytics/data/augment.py +244 -323
  14. ultralytics/data/base.py +12 -22
  15. ultralytics/data/build.py +47 -40
  16. ultralytics/data/converter.py +32 -42
  17. ultralytics/data/dataset.py +43 -71
  18. ultralytics/data/loaders.py +22 -34
  19. ultralytics/data/split.py +5 -6
  20. ultralytics/data/split_dota.py +8 -15
  21. ultralytics/data/utils.py +27 -36
  22. ultralytics/engine/exporter.py +49 -116
  23. ultralytics/engine/model.py +144 -180
  24. ultralytics/engine/predictor.py +18 -29
  25. ultralytics/engine/results.py +165 -231
  26. ultralytics/engine/trainer.py +11 -19
  27. ultralytics/engine/tuner.py +13 -23
  28. ultralytics/engine/validator.py +6 -10
  29. ultralytics/hub/__init__.py +7 -12
  30. ultralytics/hub/auth.py +6 -12
  31. ultralytics/hub/google/__init__.py +7 -10
  32. ultralytics/hub/session.py +15 -25
  33. ultralytics/hub/utils.py +3 -6
  34. ultralytics/models/fastsam/model.py +6 -8
  35. ultralytics/models/fastsam/predict.py +5 -10
  36. ultralytics/models/fastsam/utils.py +1 -2
  37. ultralytics/models/fastsam/val.py +2 -4
  38. ultralytics/models/nas/model.py +5 -8
  39. ultralytics/models/nas/predict.py +7 -9
  40. ultralytics/models/nas/val.py +1 -2
  41. ultralytics/models/rtdetr/model.py +5 -8
  42. ultralytics/models/rtdetr/predict.py +15 -18
  43. ultralytics/models/rtdetr/train.py +10 -13
  44. ultralytics/models/rtdetr/val.py +13 -20
  45. ultralytics/models/sam/amg.py +12 -18
  46. ultralytics/models/sam/build.py +6 -9
  47. ultralytics/models/sam/model.py +16 -23
  48. ultralytics/models/sam/modules/blocks.py +62 -84
  49. ultralytics/models/sam/modules/decoders.py +17 -24
  50. ultralytics/models/sam/modules/encoders.py +40 -56
  51. ultralytics/models/sam/modules/memory_attention.py +10 -16
  52. ultralytics/models/sam/modules/sam.py +41 -47
  53. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  54. ultralytics/models/sam/modules/transformer.py +17 -27
  55. ultralytics/models/sam/modules/utils.py +31 -42
  56. ultralytics/models/sam/predict.py +172 -209
  57. ultralytics/models/utils/loss.py +14 -26
  58. ultralytics/models/utils/ops.py +13 -17
  59. ultralytics/models/yolo/classify/predict.py +8 -11
  60. ultralytics/models/yolo/classify/train.py +8 -16
  61. ultralytics/models/yolo/classify/val.py +13 -20
  62. ultralytics/models/yolo/detect/predict.py +4 -8
  63. ultralytics/models/yolo/detect/train.py +11 -20
  64. ultralytics/models/yolo/detect/val.py +38 -48
  65. ultralytics/models/yolo/model.py +35 -47
  66. ultralytics/models/yolo/obb/predict.py +5 -8
  67. ultralytics/models/yolo/obb/train.py +11 -14
  68. ultralytics/models/yolo/obb/val.py +20 -28
  69. ultralytics/models/yolo/pose/predict.py +5 -8
  70. ultralytics/models/yolo/pose/train.py +4 -8
  71. ultralytics/models/yolo/pose/val.py +31 -39
  72. ultralytics/models/yolo/segment/predict.py +9 -14
  73. ultralytics/models/yolo/segment/train.py +3 -6
  74. ultralytics/models/yolo/segment/val.py +16 -26
  75. ultralytics/models/yolo/world/train.py +8 -14
  76. ultralytics/models/yolo/world/train_world.py +11 -16
  77. ultralytics/models/yolo/yoloe/predict.py +16 -23
  78. ultralytics/models/yolo/yoloe/train.py +30 -43
  79. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  80. ultralytics/models/yolo/yoloe/val.py +15 -20
  81. ultralytics/nn/autobackend.py +10 -18
  82. ultralytics/nn/modules/activation.py +4 -6
  83. ultralytics/nn/modules/block.py +99 -185
  84. ultralytics/nn/modules/conv.py +45 -90
  85. ultralytics/nn/modules/head.py +44 -98
  86. ultralytics/nn/modules/transformer.py +44 -76
  87. ultralytics/nn/modules/utils.py +14 -19
  88. ultralytics/nn/tasks.py +86 -146
  89. ultralytics/nn/text_model.py +25 -40
  90. ultralytics/solutions/ai_gym.py +10 -16
  91. ultralytics/solutions/analytics.py +7 -10
  92. ultralytics/solutions/config.py +4 -5
  93. ultralytics/solutions/distance_calculation.py +9 -12
  94. ultralytics/solutions/heatmap.py +7 -13
  95. ultralytics/solutions/instance_segmentation.py +5 -8
  96. ultralytics/solutions/object_blurrer.py +7 -10
  97. ultralytics/solutions/object_counter.py +8 -12
  98. ultralytics/solutions/object_cropper.py +5 -8
  99. ultralytics/solutions/parking_management.py +12 -14
  100. ultralytics/solutions/queue_management.py +4 -6
  101. ultralytics/solutions/region_counter.py +7 -10
  102. ultralytics/solutions/security_alarm.py +14 -19
  103. ultralytics/solutions/similarity_search.py +7 -12
  104. ultralytics/solutions/solutions.py +31 -53
  105. ultralytics/solutions/speed_estimation.py +6 -9
  106. ultralytics/solutions/streamlit_inference.py +2 -4
  107. ultralytics/solutions/trackzone.py +7 -10
  108. ultralytics/solutions/vision_eye.py +5 -8
  109. ultralytics/trackers/basetrack.py +2 -4
  110. ultralytics/trackers/bot_sort.py +6 -11
  111. ultralytics/trackers/byte_tracker.py +10 -15
  112. ultralytics/trackers/track.py +3 -6
  113. ultralytics/trackers/utils/gmc.py +6 -12
  114. ultralytics/trackers/utils/kalman_filter.py +35 -43
  115. ultralytics/trackers/utils/matching.py +6 -10
  116. ultralytics/utils/__init__.py +61 -100
  117. ultralytics/utils/autobatch.py +2 -4
  118. ultralytics/utils/autodevice.py +11 -13
  119. ultralytics/utils/benchmarks.py +25 -35
  120. ultralytics/utils/callbacks/base.py +8 -10
  121. ultralytics/utils/callbacks/clearml.py +2 -4
  122. ultralytics/utils/callbacks/comet.py +30 -44
  123. ultralytics/utils/callbacks/dvc.py +13 -18
  124. ultralytics/utils/callbacks/mlflow.py +4 -5
  125. ultralytics/utils/callbacks/neptune.py +4 -6
  126. ultralytics/utils/callbacks/raytune.py +3 -4
  127. ultralytics/utils/callbacks/tensorboard.py +4 -6
  128. ultralytics/utils/callbacks/wb.py +10 -13
  129. ultralytics/utils/checks.py +29 -56
  130. ultralytics/utils/cpu.py +1 -2
  131. ultralytics/utils/dist.py +8 -12
  132. ultralytics/utils/downloads.py +17 -27
  133. ultralytics/utils/errors.py +6 -8
  134. ultralytics/utils/events.py +2 -4
  135. ultralytics/utils/export/__init__.py +4 -239
  136. ultralytics/utils/export/engine.py +237 -0
  137. ultralytics/utils/export/imx.py +11 -17
  138. ultralytics/utils/export/tensorflow.py +217 -0
  139. ultralytics/utils/files.py +10 -15
  140. ultralytics/utils/git.py +5 -7
  141. ultralytics/utils/instance.py +30 -51
  142. ultralytics/utils/logger.py +11 -15
  143. ultralytics/utils/loss.py +8 -14
  144. ultralytics/utils/metrics.py +98 -138
  145. ultralytics/utils/nms.py +13 -16
  146. ultralytics/utils/ops.py +47 -74
  147. ultralytics/utils/patches.py +11 -18
  148. ultralytics/utils/plotting.py +29 -42
  149. ultralytics/utils/tal.py +25 -39
  150. ultralytics/utils/torch_utils.py +45 -73
  151. ultralytics/utils/tqdm.py +6 -8
  152. ultralytics/utils/triton.py +9 -12
  153. ultralytics/utils/tuner.py +1 -2
  154. dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
  155. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
  156. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
  157. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
  158. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
ultralytics/trackers/byte_tracker.py

@@ -14,8 +14,7 @@ from .utils.kalman_filter import KalmanFilterXYAH
 
 
 class STrack(BaseTrack):
-    """
-    Single object tracking representation that uses Kalman filtering for state estimation.
+    """Single object tracking representation that uses Kalman filtering for state estimation.
 
     This class is responsible for storing all the information regarding individual tracklets and performs state updates
     and predictions based on Kalman filter.
@@ -54,12 +53,11 @@ class STrack(BaseTrack):
     shared_kalman = KalmanFilterXYAH()
 
     def __init__(self, xywh: list[float], score: float, cls: Any):
-        """
-        Initialize a new STrack instance.
+        """Initialize a new STrack instance.
 
         Args:
-            xywh (list[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
-                (x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
+            xywh (list[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where (x,
+                y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
             score (float): Confidence score of the detection.
             cls (Any): Class label for the detected object.
 
@@ -154,8 +152,7 @@ class STrack(BaseTrack):
         self.idx = new_track.idx
 
     def update(self, new_track: STrack, frame_id: int):
-        """
-        Update the state of a matched track.
+        """Update the state of a matched track.
 
         Args:
             new_track (STrack): The new track containing updated information.
@@ -238,12 +235,11 @@ class STrack(BaseTrack):
 
 
 class BYTETracker:
-    """
-    BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
+    """BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
 
-    This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects in a
-    video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
-    predicting the new object locations, and performs data association.
+    This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects
+    in a video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman
+    filtering for predicting the new object locations, and performs data association.
 
     Attributes:
         tracked_stracks (list[STrack]): List of successfully activated tracks.
@@ -274,8 +270,7 @@ class BYTETracker:
     """
 
     def __init__(self, args, frame_rate: int = 30):
-        """
-        Initialize a BYTETracker instance for object tracking.
+        """Initialize a BYTETracker instance for object tracking.
 
         Args:
            args (Namespace): Command-line arguments containing tracking parameters.
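
Context for the byte_tracker.py hunks above: BYTETracker is normally driven through the YOLO tracking entry point rather than instantiated directly. A minimal usage sketch (the checkpoint and video names are placeholders, not taken from this diff):

from ultralytics import YOLO

model = YOLO("yolo11n.pt")  # placeholder detection checkpoint
# persist=True keeps tracker state across successive calls on the same stream
for result in model.track(source="video.mp4", tracker="bytetrack.yaml", persist=True, stream=True):
    if result.boxes.id is not None:
        print(result.boxes.id.int().tolist())  # per-box track IDs assigned by BYTETracker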
ultralytics/trackers/track.py

@@ -16,8 +16,7 @@ TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
 
 
 def on_predict_start(predictor: object, persist: bool = False) -> None:
-    """
-    Initialize trackers for object tracking during prediction.
+    """Initialize trackers for object tracking during prediction.
 
     Args:
         predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.
@@ -70,8 +69,7 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
 
 
 def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
-    """
-    Postprocess detected boxes and update with object tracking.
+    """Postprocess detected boxes and update with object tracking.
 
     Args:
         predictor (object): The predictor object containing the predictions.
@@ -103,8 +101,7 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
 
 
 def register_tracker(model: object, persist: bool) -> None:
-    """
-    Register tracking callbacks to the model for object tracking during prediction.
+    """Register tracking callbacks to the model for object tracking during prediction.
 
     Args:
         model (object): The model object to register tracking callbacks for.
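
register_tracker above wires on_predict_start and on_predict_postprocess_end into the predictor's callback lists. A rough sketch of the same hook mechanism through the public add_callback API (the callback body is illustrative and assumes predictor.results is populated at this point):

from ultralytics import YOLO

def count_tracked(predictor):
    # Illustrative callback: runs after track IDs have been written back to the results
    n = sum(len(r.boxes.id) for r in predictor.results if r.boxes.id is not None)
    print(f"tracked boxes this batch: {n}")

model = YOLO("yolo11n.pt")  # placeholder checkpoint
model.add_callback("on_predict_postprocess_end", count_tracked)
model.track("video.mp4", tracker="bytetrack.yaml")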
ultralytics/trackers/utils/gmc.py

@@ -11,8 +11,7 @@ from ultralytics.utils import LOGGER
 
 
 class GMC:
-    """
-    Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
+    """Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
 
     This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,
     SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.
@@ -43,8 +42,7 @@ class GMC:
     """
 
     def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
-        """
-        Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.
+        """Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.
 
         Args:
            method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
@@ -91,8 +89,7 @@ class GMC:
         self.initializedFirstFrame = False
 
     def apply(self, raw_frame: np.ndarray, detections: list | None = None) -> np.ndarray:
-        """
-        Apply object detection on a raw frame using the specified method.
+        """Apply object detection on a raw frame using the specified method.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
@@ -118,8 +115,7 @@ class GMC:
             return np.eye(2, 3)
 
     def apply_ecc(self, raw_frame: np.ndarray) -> np.ndarray:
-        """
-        Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.
+        """Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
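
As a rough illustration of what apply_ecc describes (not the package implementation), ECC motion compensation between two grayscale frames can be sketched with OpenCV as follows:

import cv2
import numpy as np

def ecc_warp(prev_gray: np.ndarray, curr_gray: np.ndarray) -> np.ndarray:
    """Estimate a 2x3 Euclidean warp from prev_gray to curr_gray (illustrative sketch)."""
    warp = np.eye(2, 3, dtype=np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 100, 1e-5)
    # findTransformECC iteratively maximizes the enhanced correlation coefficient between frames
    _, warp = cv2.findTransformECC(prev_gray, curr_gray, warp, cv2.MOTION_EUCLIDEAN, criteria, None, 1)
    return warp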
@@ -158,8 +154,7 @@ class GMC:
         return H
 
     def apply_features(self, raw_frame: np.ndarray, detections: list | None = None) -> np.ndarray:
-        """
-        Apply feature-based methods like ORB or SIFT to a raw frame.
+        """Apply feature-based methods like ORB or SIFT to a raw frame.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
@@ -276,8 +271,7 @@ class GMC:
         return H
 
     def apply_sparseoptflow(self, raw_frame: np.ndarray) -> np.ndarray:
-        """
-        Apply Sparse Optical Flow method to a raw frame.
+        """Apply Sparse Optical Flow method to a raw frame.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
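
The sparse-optical-flow path described above amounts to tracking keypoints between frames and fitting a rigid transform to the matches; a hedged OpenCV sketch, independent of the package code:

import cv2
import numpy as np

def sparse_flow_warp(prev_gray: np.ndarray, curr_gray: np.ndarray) -> np.ndarray:
    """Estimate a 2x3 affine warp from keypoints tracked with Lucas-Kanade optical flow."""
    pts = cv2.goodFeaturesToTrack(prev_gray, maxCorners=1000, qualityLevel=0.01, minDistance=1)
    if pts is None:
        return np.eye(2, 3)
    nxt, status, _ = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, pts, None)
    good_prev, good_next = pts[status.ravel() == 1], nxt[status.ravel() == 1]
    if len(good_prev) < 4:
        return np.eye(2, 3)
    warp, _ = cv2.estimateAffinePartial2D(good_prev, good_next, method=cv2.RANSAC)
    return warp if warp is not None else np.eye(2, 3)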
ultralytics/trackers/utils/kalman_filter.py

@@ -5,11 +5,10 @@ import scipy.linalg
 
 
 class KalmanFilterXYAH:
-    """
-    A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.
+    """A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.
 
-    Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space
-    (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their
+    Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space (x, y,
+    a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their
     respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is
     taken as a direct observation of the state space (linear observation model).
 
@@ -37,8 +36,7 @@ class KalmanFilterXYAH:
     """
 
     def __init__(self):
-        """
-        Initialize Kalman filter model matrices with motion and observation uncertainty weights.
+        """Initialize Kalman filter model matrices with motion and observation uncertainty weights.
 
         The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y)
         represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective
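
To make the constant-velocity model in the KalmanFilterXYAH docstrings concrete, here is a minimal sketch (not the package code) of the 8x8 transition matrix and 4x8 observation matrix it implies:

import numpy as np

ndim, dt = 4, 1.0
# Transition: each position component (x, y, a, h) is advanced by its velocity every frame
F = np.eye(2 * ndim)
for i in range(ndim):
    F[i, ndim + i] = dt
# Observation: the measurement (x, y, a, h) reads the position block of the state directly
H = np.eye(ndim, 2 * ndim)

state = np.array([320.0, 240.0, 0.5, 80.0, 2.0, 0.0, 0.0, 1.0])  # example (x, y, a, h, vx, vy, va, vh)
print(F @ state)  # -> [322., 240., 0.5, 81., 2., 0., 0., 1.]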
@@ -62,15 +60,15 @@ class KalmanFilterXYAH:
         self._std_weight_velocity = 1.0 / 160
 
     def initiate(self, measurement: np.ndarray):
-        """
-        Create a track from an unassociated measurement.
+        """Create a track from an unassociated measurement.
 
         Args:
             measurement (np.ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
                 and height h.
 
         Returns:
-            mean (np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
+            mean (np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0
+                mean.
             covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
 
         Examples:
@@ -96,12 +94,12 @@ class KalmanFilterXYAH:
         return mean, covariance
 
     def predict(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Run Kalman filter prediction step.
+        """Run Kalman filter prediction step.
 
         Args:
             mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.
-            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
+            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time
+                step.
 
         Returns:
             mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
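
The prediction step documented above is the standard Kalman time update; a hedged sketch of the underlying algebra (the process noise Q is a placeholder here, whereas the package derives it from the current box height):

import numpy as np

def kf_predict(mean: np.ndarray, cov: np.ndarray, F: np.ndarray, Q: np.ndarray):
    """Kalman time update: propagate mean and covariance through the motion model."""
    mean = F @ mean           # x' = F x
    cov = F @ cov @ F.T + Q   # P' = F P F^T + Q
    return mean, cov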
@@ -133,8 +131,7 @@ class KalmanFilterXYAH:
         return mean, covariance
 
     def project(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Project state distribution to measurement space.
+        """Project state distribution to measurement space.
 
         Args:
             mean (np.ndarray): The state's mean vector (8 dimensional array).
@@ -163,8 +160,7 @@ class KalmanFilterXYAH:
         return mean, covariance + innovation_cov
 
     def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Run Kalman filter prediction step for multiple object states (Vectorized version).
+        """Run Kalman filter prediction step for multiple object states (Vectorized version).
 
         Args:
             mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
@@ -203,8 +199,7 @@ class KalmanFilterXYAH:
         return mean, covariance
 
     def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
-        """
-        Run Kalman filter correction step.
+        """Run Kalman filter correction step.
 
         Args:
             mean (np.ndarray): The predicted state's mean vector (8 dimensional).
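
The correction step above follows the textbook Kalman measurement update; a minimal sketch in a generic inverse-based form (the package itself uses a Cholesky-based solve):

import numpy as np

def kf_update(mean, cov, z, H, R):
    """Kalman measurement update with measurement z, observation model H, and noise R."""
    S = H @ cov @ H.T + R               # innovation covariance
    K = cov @ H.T @ np.linalg.inv(S)    # Kalman gain
    mean = mean + K @ (z - H @ mean)    # corrected state mean
    cov = cov - K @ S @ K.T             # corrected state covariance
    return mean, cov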
@@ -243,8 +238,7 @@ class KalmanFilterXYAH:
         only_position: bool = False,
         metric: str = "maha",
     ) -> np.ndarray:
-        """
-        Compute gating distance between state distribution and measurements.
+        """Compute gating distance between state distribution and measurements.
 
         A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square
         distribution has 4 degrees of freedom, otherwise 2.
@@ -252,11 +246,12 @@ class KalmanFilterXYAH:
         Args:
             mean (np.ndarray): Mean vector over the state distribution (8 dimensional).
             covariance (np.ndarray): Covariance of the state distribution (8x8 dimensional).
-            measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the
-                bounding box center position, a the aspect ratio, and h the height.
-            only_position (bool, optional): If True, distance computation is done with respect to box center position only.
-            metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the squared
-                Euclidean distance and 'maha' for the squared Mahalanobis distance.
+            measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is
+                the bounding box center position, a the aspect ratio, and h the height.
+            only_position (bool, optional): If True, distance computation is done with respect to box center position
+                only.
+            metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the
+                squared Euclidean distance and 'maha' for the squared Mahalanobis distance.
 
         Returns:
             (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
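
For gating_distance above, the 'maha' metric is a squared Mahalanobis distance compared against a chi-square 0.95 quantile (9.4877 for 4 degrees of freedom); a hedged sketch of that computation:

import numpy as np

CHI2INV95_4DOF = 9.4877  # 0.95 quantile of the chi-square distribution with 4 degrees of freedom

def squared_mahalanobis(proj_mean: np.ndarray, proj_cov: np.ndarray, measurements: np.ndarray) -> np.ndarray:
    """Squared Mahalanobis distance of (N, 4) measurements from a projected state distribution."""
    L = np.linalg.cholesky(proj_cov)
    d = measurements - proj_mean        # (N, 4) residuals
    z = np.linalg.solve(L, d.T)         # whiten the residuals: solve L z = d^T
    return np.sum(z * z, axis=0)        # (N,) squared distances

# A measurement is gated in when its squared distance is below CHI2INV95_4DOF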
@@ -287,12 +282,11 @@ class KalmanFilterXYAH:
 
 
 class KalmanFilterXYWH(KalmanFilterXYAH):
-    """
-    A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.
+    """A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.
 
-    Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where
-    (x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.
-    The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct
+    Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where (x, y)
+    is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities. The
+    object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct
     observation of the state space (linear observation model).
 
     Attributes:
@@ -318,14 +312,15 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
     """
 
     def initiate(self, measurement: np.ndarray):
-        """
-        Create track from unassociated measurement.
+        """Create track from unassociated measurement.
 
         Args:
-            measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.
+            measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and
+                height.
 
         Returns:
-            mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
+            mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0
+                mean.
             covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
 
         Examples:
@@ -362,12 +357,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
         return mean, covariance
 
     def predict(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Run Kalman filter prediction step.
+        """Run Kalman filter prediction step.
 
         Args:
             mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.
-            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
+            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time
+                step.
 
         Returns:
             mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
@@ -399,8 +394,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
         return mean, covariance
 
     def project(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Project state distribution to measurement space.
+        """Project state distribution to measurement space.
 
         Args:
             mean (np.ndarray): The state's mean vector (8 dimensional array).
@@ -429,8 +423,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
         return mean, covariance + innovation_cov
 
     def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Run Kalman filter prediction step (Vectorized version).
+        """Run Kalman filter prediction step (Vectorized version).
 
         Args:
             mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
@@ -470,8 +463,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
         return mean, covariance
 
     def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
-        """
-        Run Kalman filter correction step.
+        """Run Kalman filter correction step.
 
         Args:
             mean (np.ndarray): The predicted state's mean vector (8 dimensional).
ultralytics/trackers/utils/matching.py

@@ -18,8 +18,7 @@ except (ImportError, AssertionError, AttributeError):
 
 
 def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True):
-    """
-    Perform linear assignment using either the scipy or lap.lapjv method.
+    """Perform linear assignment using either the scipy or lap.lapjv method.
 
     Args:
         cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
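
The scipy fallback mentioned in linear_assignment above looks roughly like this sketch: run the Hungarian algorithm, then drop pairs whose cost exceeds the threshold and report unmatched rows and columns separately.

import numpy as np
from scipy.optimize import linear_sum_assignment

def assign(cost_matrix: np.ndarray, thresh: float):
    """Hungarian assignment with a cost threshold (illustrative sketch, not the package code)."""
    rows, cols = linear_sum_assignment(cost_matrix)
    matches = [(r, c) for r, c in zip(rows, cols) if cost_matrix[r, c] <= thresh]
    matched_a = {r for r, _ in matches}
    matched_b = {c for _, c in matches}
    unmatched_a = [r for r in range(cost_matrix.shape[0]) if r not in matched_a]
    unmatched_b = [c for c in range(cost_matrix.shape[1]) if c not in matched_b]
    return matches, unmatched_a, unmatched_b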
@@ -62,8 +61,7 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
 
 
 def iou_distance(atracks: list, btracks: list) -> np.ndarray:
-    """
-    Compute cost based on Intersection over Union (IoU) between tracks.
+    """Compute cost based on Intersection over Union (IoU) between tracks.
 
     Args:
         atracks (list[STrack] | list[np.ndarray]): List of tracks 'a' or bounding boxes.
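
iou_distance above turns box overlap into an assignment cost; a minimal sketch of the idea for plain (x1, y1, x2, y2) arrays (the package operates on track objects and its own batched IoU helper):

import numpy as np

def iou_cost(boxes_a: np.ndarray, boxes_b: np.ndarray) -> np.ndarray:
    """Return an (N, M) cost matrix of 1 - IoU for boxes in (x1, y1, x2, y2) format."""
    a, b = boxes_a[:, None, :], boxes_b[None, :, :]
    lt = np.maximum(a[..., :2], b[..., :2])   # top-left corners of the intersections
    rb = np.minimum(a[..., 2:], b[..., 2:])   # bottom-right corners of the intersections
    wh = np.clip(rb - lt, 0, None)
    inter = wh[..., 0] * wh[..., 1]
    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])
    return 1.0 - inter / np.clip(area_a + area_b - inter, 1e-9, None)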
@@ -102,8 +100,7 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
 
 
 def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
-    """
-    Compute distance between tracks and detections based on embeddings.
+    """Compute distance between tracks and detections based on embeddings.
 
     Args:
         tracks (list[STrack]): List of tracks, where each track contains embedding features.
@@ -111,8 +108,8 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -
         metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.
 
     Returns:
-        (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks
-            and M is the number of detections.
+        (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks and M
+            is the number of detections.
 
     Examples:
         Compute the embedding distance between tracks and detections using cosine metric
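
embedding_distance above is essentially a pairwise-distance call over appearance embeddings; a hedged sketch with scipy's cdist (the feature arrays here are random placeholders):

import numpy as np
from scipy.spatial.distance import cdist

track_feats = np.random.rand(5, 128)  # placeholder: one embedding per existing track
det_feats = np.random.rand(8, 128)    # placeholder: one embedding per new detection

# (5, 8) cost matrix; clip at 0 because cosine distance can go slightly negative numerically
cost = np.maximum(0.0, cdist(track_feats, det_feats, metric="cosine"))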
@@ -132,8 +129,7 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -
 
 
 def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
-    """
-    Fuse cost matrix with detection scores to produce a single similarity matrix.
+    """Fuse cost matrix with detection scores to produce a single similarity matrix.
 
     Args:
         cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
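
fuse_score above blends the IoU-based cost with detection confidences: convert cost to similarity, scale by the scores, and convert back. A sketch of that idea (not verbatim package code):

import numpy as np

def fuse_score_sketch(cost_matrix: np.ndarray, det_scores: np.ndarray) -> np.ndarray:
    """Blend an (N, M) cost matrix with M detection confidence scores."""
    if cost_matrix.size == 0:
        return cost_matrix
    similarity = 1.0 - cost_matrix                  # higher is better
    similarity = similarity * det_scores[None, :]   # down-weight low-confidence detections
    return 1.0 - similarity                         # back to a cost matrix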