ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +37 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +111 -41
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +579 -244
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +191 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +226 -82
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +172 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +305 -112
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.63.dist-info/METADATA +370 -0
  235. ultralytics-8.3.63.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import numpy as np
4
4
  import scipy.linalg
@@ -6,17 +6,49 @@ import scipy.linalg
6
6
 
7
7
  class KalmanFilterXYAH:
8
8
  """
9
- For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
10
-
11
- The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect
12
- ratio a, height h, and their respective velocities.
13
-
14
- Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
15
- observation of the state space (linear observation model).
9
+ A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.
10
+
11
+ Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space
12
+ (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their
13
+ respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is
14
+ taken as a direct observation of the state space (linear observation model).
15
+
16
+ Attributes:
17
+ _motion_mat (np.ndarray): The motion matrix for the Kalman filter.
18
+ _update_mat (np.ndarray): The update matrix for the Kalman filter.
19
+ _std_weight_position (float): Standard deviation weight for position.
20
+ _std_weight_velocity (float): Standard deviation weight for velocity.
21
+
22
+ Methods:
23
+ initiate: Creates a track from an unassociated measurement.
24
+ predict: Runs the Kalman filter prediction step.
25
+ project: Projects the state distribution to measurement space.
26
+ multi_predict: Runs the Kalman filter prediction step (vectorized version).
27
+ update: Runs the Kalman filter correction step.
28
+ gating_distance: Computes the gating distance between state distribution and measurements.
29
+
30
+ Examples:
31
+ Initialize the Kalman filter and create a track from a measurement
32
+ >>> kf = KalmanFilterXYAH()
33
+ >>> measurement = np.array([100, 200, 1.5, 50])
34
+ >>> mean, covariance = kf.initiate(measurement)
35
+ >>> print(mean)
36
+ >>> print(covariance)
16
37
  """
17
38
 
18
39
  def __init__(self):
19
- """Initialize Kalman filter model matrices with motion and observation uncertainty weights."""
40
+ """
41
+ Initialize Kalman filter model matrices with motion and observation uncertainty weights.
42
+
43
+ The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y)
44
+ represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective
45
+ velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear
46
+ observation model for bounding box location.
47
+
48
+ Examples:
49
+ Initialize a Kalman filter for tracking:
50
+ >>> kf = KalmanFilterXYAH()
51
+ """
20
52
  ndim, dt = 4, 1.0
21
53
 
22
54
  # Create Kalman filter model matrices
@@ -32,15 +64,20 @@ class KalmanFilterXYAH:
32
64
 
33
65
  def initiate(self, measurement: np.ndarray) -> tuple:
34
66
  """
35
- Create track from unassociated measurement.
67
+ Create a track from an unassociated measurement.
36
68
 
37
69
  Args:
38
70
  measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
39
71
  and height h.
40
72
 
41
73
  Returns:
42
- (tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of
43
- the new track. Unobserved velocities are initialized to 0 mean.
74
+ (tuple[ndarray, ndarray]): Returns the mean vector (8-dimensional) and covariance matrix (8x8 dimensional)
75
+ of the new track. Unobserved velocities are initialized to 0 mean.
76
+
77
+ Examples:
78
+ >>> kf = KalmanFilterXYAH()
79
+ >>> measurement = np.array([100, 50, 1.5, 200])
80
+ >>> mean, covariance = kf.initiate(measurement)
44
81
  """
45
82
  mean_pos = measurement
46
83
  mean_vel = np.zeros_like(mean_pos)
@@ -64,12 +101,18 @@ class KalmanFilterXYAH:
64
101
  Run Kalman filter prediction step.
65
102
 
66
103
  Args:
67
- mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
68
- covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
104
+ mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step.
105
+ covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
69
106
 
70
107
  Returns:
71
108
  (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
72
109
  velocities are initialized to 0 mean.
110
+
111
+ Examples:
112
+ >>> kf = KalmanFilterXYAH()
113
+ >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
114
+ >>> covariance = np.eye(8)
115
+ >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
73
116
  """
74
117
  std_pos = [
75
118
  self._std_weight_position * mean[3],
@@ -100,6 +143,12 @@ class KalmanFilterXYAH:
100
143
 
101
144
  Returns:
102
145
  (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
146
+
147
+ Examples:
148
+ >>> kf = KalmanFilterXYAH()
149
+ >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
150
+ >>> covariance = np.eye(8)
151
+ >>> projected_mean, projected_covariance = kf.project(mean, covariance)
103
152
  """
104
153
  std = [
105
154
  self._std_weight_position * mean[3],
@@ -115,15 +164,21 @@ class KalmanFilterXYAH:
115
164
 
116
165
  def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
117
166
  """
118
- Run Kalman filter prediction step (Vectorized version).
167
+ Run Kalman filter prediction step for multiple object states (Vectorized version).
119
168
 
120
169
  Args:
121
170
  mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
122
171
  covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
123
172
 
124
173
  Returns:
125
- (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
126
- velocities are initialized to 0 mean.
174
+ (tuple[ndarray, ndarray]): Returns the mean matrix and covariance matrix of the predicted states.
175
+ The mean matrix has shape (N, 8) and the covariance matrix has shape (N, 8, 8). Unobserved velocities
176
+ are initialized to 0 mean.
177
+
178
+ Examples:
179
+ >>> mean = np.random.rand(10, 8) # 10 object states
180
+ >>> covariance = np.random.rand(10, 8, 8) # Covariance matrices for 10 object states
181
+ >>> predicted_mean, predicted_covariance = kalman_filter.multi_predict(mean, covariance)
127
182
  """
128
183
  std_pos = [
129
184
  self._std_weight_position * mean[:, 3],
@@ -160,6 +215,13 @@ class KalmanFilterXYAH:
160
215
 
161
216
  Returns:
162
217
  (tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
218
+
219
+ Examples:
220
+ >>> kf = KalmanFilterXYAH()
221
+ >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
222
+ >>> covariance = np.eye(8)
223
+ >>> measurement = np.array([1, 1, 1, 1])
224
+ >>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
163
225
  """
164
226
  projected_mean, projected_cov = self.project(mean, covariance)
165
227
 
@@ -182,23 +244,31 @@ class KalmanFilterXYAH:
182
244
  metric: str = "maha",
183
245
  ) -> np.ndarray:
184
246
  """
185
- Compute gating distance between state distribution and measurements. A suitable distance threshold can be
186
- obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
187
- otherwise 2.
247
+ Compute gating distance between state distribution and measurements.
248
+
249
+ A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square
250
+ distribution has 4 degrees of freedom, otherwise 2.
188
251
 
189
252
  Args:
190
253
  mean (ndarray): Mean vector over the state distribution (8 dimensional).
191
254
  covariance (ndarray): Covariance of the state distribution (8x8 dimensional).
192
- measurements (ndarray): An Nx4 matrix of N measurements, each in format (x, y, a, h) where (x, y)
193
- is the bounding box center position, a the aspect ratio, and h the height.
194
- only_position (bool, optional): If True, distance computation is done with respect to the bounding box
195
- center position only. Defaults to False.
196
- metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the
197
- squared Euclidean distance and 'maha' for the squared Mahalanobis distance. Defaults to 'maha'.
255
+ measurements (ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the
256
+ bounding box center position, a the aspect ratio, and h the height.
257
+ only_position (bool): If True, distance computation is done with respect to box center position only.
258
+ metric (str): The metric to use for calculating the distance. Options are 'gaussian' for the squared
259
+ Euclidean distance and 'maha' for the squared Mahalanobis distance.
198
260
 
199
261
  Returns:
200
262
  (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
201
263
  (mean, covariance) and `measurements[i]`.
264
+
265
+ Examples:
266
+ Compute gating distance using Mahalanobis metric:
267
+ >>> kf = KalmanFilterXYAH()
268
+ >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
269
+ >>> covariance = np.eye(8)
270
+ >>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]])
271
+ >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric="maha")
202
272
  """
203
273
  mean, covariance = self.project(mean, covariance)
204
274
  if only_position:
@@ -218,13 +288,33 @@ class KalmanFilterXYAH:
218
288
 
219
289
  class KalmanFilterXYWH(KalmanFilterXYAH):
220
290
  """
221
- For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
222
-
223
- The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width
224
- w, height h, and their respective velocities.
291
+ A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.
225
292
 
226
- Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
293
+ Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where
294
+ (x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.
295
+ The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct
227
296
  observation of the state space (linear observation model).
297
+
298
+ Attributes:
299
+ _motion_mat (np.ndarray): The motion matrix for the Kalman filter.
300
+ _update_mat (np.ndarray): The update matrix for the Kalman filter.
301
+ _std_weight_position (float): Standard deviation weight for position.
302
+ _std_weight_velocity (float): Standard deviation weight for velocity.
303
+
304
+ Methods:
305
+ initiate: Creates a track from an unassociated measurement.
306
+ predict: Runs the Kalman filter prediction step.
307
+ project: Projects the state distribution to measurement space.
308
+ multi_predict: Runs the Kalman filter prediction step in a vectorized manner.
309
+ update: Runs the Kalman filter correction step.
310
+
311
+ Examples:
312
+ Create a Kalman filter and initialize a track
313
+ >>> kf = KalmanFilterXYWH()
314
+ >>> measurement = np.array([100, 50, 20, 40])
315
+ >>> mean, covariance = kf.initiate(measurement)
316
+ >>> print(mean)
317
+ >>> print(covariance)
228
318
  """
229
319
 
230
320
  def initiate(self, measurement: np.ndarray) -> tuple:
@@ -235,8 +325,24 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
235
325
  measurement (ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.
236
326
 
237
327
  Returns:
238
- (tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of
239
- the new track. Unobserved velocities are initialized to 0 mean.
328
+ (tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional)
329
+ of the new track. Unobserved velocities are initialized to 0 mean.
330
+
331
+ Examples:
332
+ >>> kf = KalmanFilterXYWH()
333
+ >>> measurement = np.array([100, 50, 20, 40])
334
+ >>> mean, covariance = kf.initiate(measurement)
335
+ >>> print(mean)
336
+ [100. 50. 20. 40. 0. 0. 0. 0.]
337
+ >>> print(covariance)
338
+ [[ 4. 0. 0. 0. 0. 0. 0. 0.]
339
+ [ 0. 4. 0. 0. 0. 0. 0. 0.]
340
+ [ 0. 0. 4. 0. 0. 0. 0. 0.]
341
+ [ 0. 0. 0. 4. 0. 0. 0. 0.]
342
+ [ 0. 0. 0. 0. 0.25 0. 0. 0.]
343
+ [ 0. 0. 0. 0. 0. 0.25 0. 0.]
344
+ [ 0. 0. 0. 0. 0. 0. 0.25 0.]
345
+ [ 0. 0. 0. 0. 0. 0. 0. 0.25]]
240
346
  """
241
347
  mean_pos = measurement
242
348
  mean_vel = np.zeros_like(mean_pos)
@@ -260,12 +366,18 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
260
366
  Run Kalman filter prediction step.
261
367
 
262
368
  Args:
263
- mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
264
- covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
369
+ mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step.
370
+ covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
265
371
 
266
372
  Returns:
267
373
  (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
268
374
  velocities are initialized to 0 mean.
375
+
376
+ Examples:
377
+ >>> kf = KalmanFilterXYWH()
378
+ >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
379
+ >>> covariance = np.eye(8)
380
+ >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
269
381
  """
270
382
  std_pos = [
271
383
  self._std_weight_position * mean[2],
@@ -296,6 +408,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
296
408
 
297
409
  Returns:
298
410
  (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
411
+
412
+ Examples:
413
+ >>> kf = KalmanFilterXYWH()
414
+ >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
415
+ >>> covariance = np.eye(8)
416
+ >>> projected_mean, projected_cov = kf.project(mean, covariance)
299
417
  """
300
418
  std = [
301
419
  self._std_weight_position * mean[2],
@@ -320,6 +438,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
320
438
  Returns:
321
439
  (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
322
440
  velocities are initialized to 0 mean.
441
+
442
+ Examples:
443
+ >>> mean = np.random.rand(5, 8) # 5 objects with 8-dimensional state vectors
444
+ >>> covariance = np.random.rand(5, 8, 8) # 5 objects with 8x8 covariance matrices
445
+ >>> kf = KalmanFilterXYWH()
446
+ >>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance)
323
447
  """
324
448
  std_pos = [
325
449
  self._std_weight_position * mean[:, 2],
@@ -356,5 +480,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
356
480
 
357
481
  Returns:
358
482
  (tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
483
+
484
+ Examples:
485
+ >>> kf = KalmanFilterXYWH()
486
+ >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
487
+ >>> covariance = np.eye(8)
488
+ >>> measurement = np.array([0.5, 0.5, 1.2, 1.2])
489
+ >>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
359
490
  """
360
491
  return super().update(mean, covariance, measurement)
@@ -1,10 +1,10 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import numpy as np
4
4
  import scipy
5
5
  from scipy.spatial.distance import cdist
6
6
 
7
- from ultralytics.utils.metrics import bbox_ioa, batch_probiou
7
+ from ultralytics.utils.metrics import batch_probiou, bbox_ioa
8
8
 
9
9
  try:
10
10
  import lap # for linear_assignment
@@ -13,26 +13,29 @@ try:
13
13
  except (ImportError, AssertionError, AttributeError):
14
14
  from ultralytics.utils.checks import check_requirements
15
15
 
16
- check_requirements("lapx>=0.5.2") # update to lap package from https://github.com/rathaROG/lapx
16
+ check_requirements("lap>=0.5.12") # https://github.com/gatagat/lap
17
17
  import lap
18
18
 
19
19
 
20
20
  def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
21
21
  """
22
- Perform linear assignment using scipy or lap.lapjv.
22
+ Perform linear assignment using either the scipy or lap.lapjv method.
23
23
 
24
24
  Args:
25
- cost_matrix (np.ndarray): The matrix containing cost values for assignments.
25
+ cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
26
26
  thresh (float): Threshold for considering an assignment valid.
27
- use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True.
27
+ use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.
28
28
 
29
29
  Returns:
30
- Tuple with:
31
- - matched indices
32
- - unmatched indices from 'a'
33
- - unmatched indices from 'b'
30
+ matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
31
+ unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).
32
+ unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).
33
+
34
+ Examples:
35
+ >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
36
+ >>> thresh = 5.0
37
+ >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
34
38
  """
35
-
36
39
  if cost_matrix.size == 0:
37
40
  return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
38
41
 
@@ -68,8 +71,13 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
68
71
 
69
72
  Returns:
70
73
  (np.ndarray): Cost matrix computed based on IoU.
71
- """
72
74
 
75
+ Examples:
76
+ Compute IoU distance between two sets of tracks
77
+ >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]
78
+ >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
79
+ >>> cost_matrix = iou_distance(atracks, btracks)
80
+ """
73
81
  if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
74
82
  atlbrs = atracks
75
83
  btlbrs = btracks
@@ -98,14 +106,20 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -
98
106
  Compute distance between tracks and detections based on embeddings.
99
107
 
100
108
  Args:
101
- tracks (list[STrack]): List of tracks.
102
- detections (list[BaseTrack]): List of detections.
103
- metric (str, optional): Metric for distance computation. Defaults to 'cosine'.
109
+ tracks (list[STrack]): List of tracks, where each track contains embedding features.
110
+ detections (list[BaseTrack]): List of detections, where each detection contains embedding features.
111
+ metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.
104
112
 
105
113
  Returns:
106
- (np.ndarray): Cost matrix computed based on embeddings.
114
+ (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks
115
+ and M is the number of detections.
116
+
117
+ Examples:
118
+ Compute the embedding distance between tracks and detections using cosine metric
119
+ >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features
120
+ >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features
121
+ >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine")
107
122
  """
108
-
109
123
  cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
110
124
  if cost_matrix.size == 0:
111
125
  return cost_matrix
@@ -122,13 +136,18 @@ def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
122
136
  Fuses cost matrix with detection scores to produce a single similarity matrix.
123
137
 
124
138
  Args:
125
- cost_matrix (np.ndarray): The matrix containing cost values for assignments.
126
- detections (list[BaseTrack]): List of detections with scores.
139
+ cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
140
+ detections (list[BaseTrack]): List of detections, each containing a score attribute.
127
141
 
128
142
  Returns:
129
- (np.ndarray): Fused similarity matrix.
130
- """
143
+ (np.ndarray): Fused similarity matrix with shape (N, M).
131
144
 
145
+ Examples:
146
+ Fuse a cost matrix with detection scores
147
+ >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections
148
+ >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]
149
+ >>> fused_matrix = fuse_score(cost_matrix, detections)
150
+ """
132
151
  if cost_matrix.size == 0:
133
152
  return cost_matrix
134
153
  iou_sim = 1 - cost_matrix