ultralytics-opencv-headless 8.3.242 (ultralytics_opencv_headless-8.3.242-py3-none-any.whl)

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (298)
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1574 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +73 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +998 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +444 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1560 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
ultralytics/trackers/utils/kalman_filter.py
@@ -0,0 +1,482 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import numpy as np
+import scipy.linalg
+
+
+class KalmanFilterXYAH:
+    """A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.
+
+    Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space (x, y,
+    a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their
+    respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is
+    taken as a direct observation of the state space (linear observation model).
+
+    Attributes:
+        _motion_mat (np.ndarray): The motion matrix for the Kalman filter.
+        _update_mat (np.ndarray): The update matrix for the Kalman filter.
+        _std_weight_position (float): Standard deviation weight for position.
+        _std_weight_velocity (float): Standard deviation weight for velocity.
+
+    Methods:
+        initiate: Create a track from an unassociated measurement.
+        predict: Run the Kalman filter prediction step.
+        project: Project the state distribution to measurement space.
+        multi_predict: Run the Kalman filter prediction step (vectorized version).
+        update: Run the Kalman filter correction step.
+        gating_distance: Compute the gating distance between state distribution and measurements.
+
+    Examples:
+        Initialize the Kalman filter and create a track from a measurement
+        >>> kf = KalmanFilterXYAH()
+        >>> measurement = np.array([100, 200, 1.5, 50])
+        >>> mean, covariance = kf.initiate(measurement)
+        >>> print(mean)
+        >>> print(covariance)
+    """
+
+    def __init__(self):
+        """Initialize Kalman filter model matrices with motion and observation uncertainty weights.
+
+        The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y)
+        represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective
+        velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear
+        observation model for bounding box location.
+        """
+        ndim, dt = 4, 1.0
+
+        # Create Kalman filter model matrices
+        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
+        for i in range(ndim):
+            self._motion_mat[i, ndim + i] = dt
+        self._update_mat = np.eye(ndim, 2 * ndim)
+
+        # Motion and observation uncertainty are chosen relative to the current state estimate
+        self._std_weight_position = 1.0 / 20
+        self._std_weight_velocity = 1.0 / 160
+
+    def initiate(self, measurement: np.ndarray):
+        """Create a track from an unassociated measurement.
+
+        Args:
+            measurement (np.ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
+                and height h.
+
+        Returns:
+            mean (np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0
+                mean.
+            covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
+
+        Examples:
+            >>> kf = KalmanFilterXYAH()
+            >>> measurement = np.array([100, 50, 1.5, 200])
+            >>> mean, covariance = kf.initiate(measurement)
+        """
+        mean_pos = measurement
+        mean_vel = np.zeros_like(mean_pos)
+        mean = np.r_[mean_pos, mean_vel]
+
+        std = [
+            2 * self._std_weight_position * measurement[3],
+            2 * self._std_weight_position * measurement[3],
+            1e-2,
+            2 * self._std_weight_position * measurement[3],
+            10 * self._std_weight_velocity * measurement[3],
+            10 * self._std_weight_velocity * measurement[3],
+            1e-5,
+            10 * self._std_weight_velocity * measurement[3],
+        ]
+        covariance = np.diag(np.square(std))
+        return mean, covariance
+
+    def predict(self, mean: np.ndarray, covariance: np.ndarray):
+        """Run Kalman filter prediction step.
+
+        Args:
+            mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.
+            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time
+                step.
+
+        Returns:
+            mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
+            covariance (np.ndarray): Covariance matrix of the predicted state.
+
+        Examples:
+            >>> kf = KalmanFilterXYAH()
+            >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
+            >>> covariance = np.eye(8)
+            >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
+        """
+        std_pos = [
+            self._std_weight_position * mean[3],
+            self._std_weight_position * mean[3],
+            1e-2,
+            self._std_weight_position * mean[3],
+        ]
+        std_vel = [
+            self._std_weight_velocity * mean[3],
+            self._std_weight_velocity * mean[3],
+            1e-5,
+            self._std_weight_velocity * mean[3],
+        ]
+        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
+
+        mean = np.dot(mean, self._motion_mat.T)
+        covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
+
+        return mean, covariance
+
+    def project(self, mean: np.ndarray, covariance: np.ndarray):
+        """Project state distribution to measurement space.
+
+        Args:
+            mean (np.ndarray): The state's mean vector (8 dimensional array).
+            covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
+
+        Returns:
+            mean (np.ndarray): Projected mean of the given state estimate.
+            covariance (np.ndarray): Projected covariance matrix of the given state estimate.
+
+        Examples:
+            >>> kf = KalmanFilterXYAH()
+            >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
+            >>> covariance = np.eye(8)
+            >>> projected_mean, projected_covariance = kf.project(mean, covariance)
+        """
+        std = [
+            self._std_weight_position * mean[3],
+            self._std_weight_position * mean[3],
+            1e-1,
+            self._std_weight_position * mean[3],
+        ]
+        innovation_cov = np.diag(np.square(std))
+
+        mean = np.dot(self._update_mat, mean)
+        covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
+        return mean, covariance + innovation_cov
+
+    def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
+        """Run Kalman filter prediction step for multiple object states (Vectorized version).
+
+        Args:
+            mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
+            covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
+
+        Returns:
+            mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
+            covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
+
+        Examples:
+            >>> kf = KalmanFilterXYAH()
+            >>> mean = np.random.rand(10, 8)  # 10 object states
+            >>> covariance = np.random.rand(10, 8, 8)  # Covariance matrices for 10 object states
+            >>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance)
+        """
+        std_pos = [
+            self._std_weight_position * mean[:, 3],
+            self._std_weight_position * mean[:, 3],
+            1e-2 * np.ones_like(mean[:, 3]),
+            self._std_weight_position * mean[:, 3],
+        ]
+        std_vel = [
+            self._std_weight_velocity * mean[:, 3],
+            self._std_weight_velocity * mean[:, 3],
+            1e-5 * np.ones_like(mean[:, 3]),
+            self._std_weight_velocity * mean[:, 3],
+        ]
+        sqr = np.square(np.r_[std_pos, std_vel]).T
+
+        motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
+        motion_cov = np.asarray(motion_cov)
+
+        mean = np.dot(mean, self._motion_mat.T)
+        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
+        covariance = np.dot(left, self._motion_mat.T) + motion_cov
+
+        return mean, covariance
+
+    def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
+        """Run Kalman filter correction step.
+
+        Args:
+            mean (np.ndarray): The predicted state's mean vector (8 dimensional).
+            covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
+            measurement (np.ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center
+                position, a the aspect ratio, and h the height of the bounding box.
+
+        Returns:
+            new_mean (np.ndarray): Measurement-corrected state mean.
+            new_covariance (np.ndarray): Measurement-corrected state covariance.
+
+        Examples:
+            >>> kf = KalmanFilterXYAH()
+            >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
+            >>> covariance = np.eye(8)
+            >>> measurement = np.array([1, 1, 1, 1])
+            >>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
+        """
+        projected_mean, projected_cov = self.project(mean, covariance)
+
+        chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
+        kalman_gain = scipy.linalg.cho_solve(
+            (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False
+        ).T
+        innovation = measurement - projected_mean
+
+        new_mean = mean + np.dot(innovation, kalman_gain.T)
+        new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
+        return new_mean, new_covariance
+
+    def gating_distance(
+        self,
+        mean: np.ndarray,
+        covariance: np.ndarray,
+        measurements: np.ndarray,
+        only_position: bool = False,
+        metric: str = "maha",
+    ) -> np.ndarray:
+        """Compute gating distance between state distribution and measurements.
+
+        A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square
+        distribution has 4 degrees of freedom, otherwise 2.
+
+        Args:
+            mean (np.ndarray): Mean vector over the state distribution (8 dimensional).
+            covariance (np.ndarray): Covariance of the state distribution (8x8 dimensional).
+            measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is
+                the bounding box center position, a the aspect ratio, and h the height.
+            only_position (bool, optional): If True, distance computation is done with respect to box center position
+                only.
+            metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the
+                squared Euclidean distance and 'maha' for the squared Mahalanobis distance.
+
+        Returns:
+            (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
+                (mean, covariance) and `measurements[i]`.
+
+        Examples:
+            Compute gating distance using Mahalanobis metric:
+            >>> kf = KalmanFilterXYAH()
+            >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
+            >>> covariance = np.eye(8)
+            >>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]])
+            >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric="maha")
+        """
+        mean, covariance = self.project(mean, covariance)
+        if only_position:
+            mean, covariance = mean[:2], covariance[:2, :2]
+            measurements = measurements[:, :2]
+
+        d = measurements - mean
+        if metric == "gaussian":
+            return np.sum(d * d, axis=1)
+        elif metric == "maha":
+            cholesky_factor = np.linalg.cholesky(covariance)
+            z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
+            return np.sum(z * z, axis=0)  # square maha
+        else:
+            raise ValueError("Invalid distance metric")
+
+
+class KalmanFilterXYWH(KalmanFilterXYAH):
+    """A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.
+
+    Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where (x, y)
+    is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities. The
+    object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct
+    observation of the state space (linear observation model).
+
+    Attributes:
+        _motion_mat (np.ndarray): The motion matrix for the Kalman filter.
+        _update_mat (np.ndarray): The update matrix for the Kalman filter.
+        _std_weight_position (float): Standard deviation weight for position.
+        _std_weight_velocity (float): Standard deviation weight for velocity.
+
+    Methods:
+        initiate: Create a track from an unassociated measurement.
+        predict: Run the Kalman filter prediction step.
+        project: Project the state distribution to measurement space.
+        multi_predict: Run the Kalman filter prediction step in a vectorized manner.
+        update: Run the Kalman filter correction step.
+
+    Examples:
+        Create a Kalman filter and initialize a track
+        >>> kf = KalmanFilterXYWH()
+        >>> measurement = np.array([100, 50, 20, 40])
+        >>> mean, covariance = kf.initiate(measurement)
+        >>> print(mean)
+        >>> print(covariance)
+    """
+
+    def initiate(self, measurement: np.ndarray):
+        """Create track from unassociated measurement.
+
+        Args:
+            measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and
+                height.
+
+        Returns:
+            mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0
+                mean.
+            covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
+
+        Examples:
+            >>> kf = KalmanFilterXYWH()
+            >>> measurement = np.array([100, 50, 20, 40])
+            >>> mean, covariance = kf.initiate(measurement)
+            >>> print(mean)
+            [100. 50. 20. 40. 0. 0. 0. 0.]
+            >>> print(covariance)
+            [[ 4. 0. 0. 0. 0. 0. 0. 0.]
+             [ 0. 4. 0. 0. 0. 0. 0. 0.]
+             [ 0. 0. 4. 0. 0. 0. 0. 0.]
+             [ 0. 0. 0. 4. 0. 0. 0. 0.]
+             [ 0. 0. 0. 0. 0.25 0. 0. 0.]
+             [ 0. 0. 0. 0. 0. 0.25 0. 0.]
+             [ 0. 0. 0. 0. 0. 0. 0.25 0.]
+             [ 0. 0. 0. 0. 0. 0. 0. 0.25]]
+        """
+        mean_pos = measurement
+        mean_vel = np.zeros_like(mean_pos)
+        mean = np.r_[mean_pos, mean_vel]
+
+        std = [
+            2 * self._std_weight_position * measurement[2],
+            2 * self._std_weight_position * measurement[3],
+            2 * self._std_weight_position * measurement[2],
+            2 * self._std_weight_position * measurement[3],
+            10 * self._std_weight_velocity * measurement[2],
+            10 * self._std_weight_velocity * measurement[3],
+            10 * self._std_weight_velocity * measurement[2],
+            10 * self._std_weight_velocity * measurement[3],
+        ]
+        covariance = np.diag(np.square(std))
+        return mean, covariance
+
+    def predict(self, mean: np.ndarray, covariance: np.ndarray):
+        """Run Kalman filter prediction step.
+
+        Args:
+            mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.
+            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time
+                step.
+
+        Returns:
+            mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
+            covariance (np.ndarray): Covariance matrix of the predicted state.
+
+        Examples:
+            >>> kf = KalmanFilterXYWH()
+            >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
+            >>> covariance = np.eye(8)
+            >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
+        """
+        std_pos = [
+            self._std_weight_position * mean[2],
+            self._std_weight_position * mean[3],
+            self._std_weight_position * mean[2],
+            self._std_weight_position * mean[3],
+        ]
+        std_vel = [
+            self._std_weight_velocity * mean[2],
+            self._std_weight_velocity * mean[3],
+            self._std_weight_velocity * mean[2],
+            self._std_weight_velocity * mean[3],
+        ]
+        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
+
+        mean = np.dot(mean, self._motion_mat.T)
+        covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
+
+        return mean, covariance
+
+    def project(self, mean: np.ndarray, covariance: np.ndarray):
+        """Project state distribution to measurement space.
+
+        Args:
+            mean (np.ndarray): The state's mean vector (8 dimensional array).
+            covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
+
+        Returns:
+            mean (np.ndarray): Projected mean of the given state estimate.
+            covariance (np.ndarray): Projected covariance matrix of the given state estimate.
+
+        Examples:
+            >>> kf = KalmanFilterXYWH()
+            >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
+            >>> covariance = np.eye(8)
+            >>> projected_mean, projected_cov = kf.project(mean, covariance)
+        """
+        std = [
+            self._std_weight_position * mean[2],
+            self._std_weight_position * mean[3],
+            self._std_weight_position * mean[2],
+            self._std_weight_position * mean[3],
+        ]
+        innovation_cov = np.diag(np.square(std))
+
+        mean = np.dot(self._update_mat, mean)
+        covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
+        return mean, covariance + innovation_cov
+
+    def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
+        """Run Kalman filter prediction step (Vectorized version).
+
+        Args:
+            mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
+            covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
+
+        Returns:
+            mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
+            covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
+
+        Examples:
+            >>> mean = np.random.rand(5, 8)  # 5 objects with 8-dimensional state vectors
+            >>> covariance = np.random.rand(5, 8, 8)  # 5 objects with 8x8 covariance matrices
+            >>> kf = KalmanFilterXYWH()
+            >>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance)
+        """
+        std_pos = [
+            self._std_weight_position * mean[:, 2],
+            self._std_weight_position * mean[:, 3],
+            self._std_weight_position * mean[:, 2],
+            self._std_weight_position * mean[:, 3],
+        ]
+        std_vel = [
+            self._std_weight_velocity * mean[:, 2],
+            self._std_weight_velocity * mean[:, 3],
+            self._std_weight_velocity * mean[:, 2],
+            self._std_weight_velocity * mean[:, 3],
+        ]
+        sqr = np.square(np.r_[std_pos, std_vel]).T
+
+        motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
+        motion_cov = np.asarray(motion_cov)
+
+        mean = np.dot(mean, self._motion_mat.T)
+        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
+        covariance = np.dot(left, self._motion_mat.T) + motion_cov
+
+        return mean, covariance
+
+    def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
+        """Run Kalman filter correction step.
+
+        Args:
+            mean (np.ndarray): The predicted state's mean vector (8 dimensional).
+            covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
+            measurement (np.ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center
+                position, w the width, and h the height of the bounding box.
+
+        Returns:
+            new_mean (np.ndarray): Measurement-corrected state mean.
+            new_covariance (np.ndarray): Measurement-corrected state covariance.
+
+        Examples:
+            >>> kf = KalmanFilterXYWH()
+            >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
+            >>> covariance = np.eye(8)
+            >>> measurement = np.array([0.5, 0.5, 1.2, 1.2])
+            >>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
+        """
+        return super().update(mean, covariance, measurement)
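A minimal usage sketch for the filters above (illustrative only: the box values are invented, and the import path is this package's ultralytics.trackers.utils.kalman_filter module; the byte_tracker.py and bot_sort.py trackers listed above wrap this same initiate → predict → update cycle internally):

import numpy as np

from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH

# Invented detections of one object over two frames, given as (x, y, a, h):
# box center, aspect ratio (width / height), and box height.
first_box = np.array([120.0, 80.0, 0.5, 60.0])
second_box = np.array([124.0, 82.0, 0.5, 61.0])

kf = KalmanFilterXYAH()
mean, covariance = kf.initiate(first_box)        # start a track: (8,) mean, (8, 8) covariance
mean, covariance = kf.predict(mean, covariance)  # constant-velocity propagation to the next frame

# Gate candidate detections with the squared Mahalanobis distance before association.
candidates = np.stack([second_box, np.array([400.0, 300.0, 0.5, 60.0])])
distances = kf.gating_distance(mean, covariance, candidates, metric="maha")

# Correct the state with the associated measurement.
mean, covariance = kf.update(mean, covariance, second_box)
print(distances, mean[:4])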
ultralytics/trackers/utils/matching.py
@@ -0,0 +1,154 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import numpy as np
+import scipy
+from scipy.spatial.distance import cdist
+
+from ultralytics.utils.metrics import batch_probiou, bbox_ioa
+
+try:
+    import lap  # for linear_assignment
+
+    assert lap.__version__  # verify package is not directory
+except (ImportError, AssertionError, AttributeError):
+    from ultralytics.utils.checks import check_requirements
+
+    check_requirements("lap>=0.5.12")  # https://github.com/gatagat/lap
+    import lap
+
+
+def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True):
+    """Perform linear assignment using either the scipy or lap.lapjv method.
+
+    Args:
+        cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
+        thresh (float): Threshold for considering an assignment valid.
+        use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.
+
+    Returns:
+        matched_indices (list[list[int]] | np.ndarray): Matched indices of shape (K, 2), where K is the number of
+            matches.
+        unmatched_a (np.ndarray): Unmatched indices from the first set, with shape (L,).
+        unmatched_b (np.ndarray): Unmatched indices from the second set, with shape (M,).
+
+    Examples:
+        >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+        >>> thresh = 5.0
+        >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
+    """
+    if cost_matrix.size == 0:
+        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
+
+    if use_lap:
+        # Use lap.lapjv
+        # https://github.com/gatagat/lap
+        _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
+        matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
+        unmatched_a = np.where(x < 0)[0]
+        unmatched_b = np.where(y < 0)[0]
+    else:
+        # Use scipy.optimize.linear_sum_assignment
+        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
+        x, y = scipy.optimize.linear_sum_assignment(cost_matrix)  # row x, col y
+        matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh])
+        if len(matches) == 0:
+            unmatched_a = list(np.arange(cost_matrix.shape[0]))
+            unmatched_b = list(np.arange(cost_matrix.shape[1]))
+        else:
+            unmatched_a = list(frozenset(np.arange(cost_matrix.shape[0])) - frozenset(matches[:, 0]))
+            unmatched_b = list(frozenset(np.arange(cost_matrix.shape[1])) - frozenset(matches[:, 1]))
+
+    return matches, unmatched_a, unmatched_b
+
+
+def iou_distance(atracks: list, btracks: list) -> np.ndarray:
+    """Compute cost based on Intersection over Union (IoU) between tracks.
+
+    Args:
+        atracks (list[STrack] | list[np.ndarray]): List of tracks 'a' or bounding boxes.
+        btracks (list[STrack] | list[np.ndarray]): List of tracks 'b' or bounding boxes.
+
+    Returns:
+        (np.ndarray): Cost matrix computed based on IoU with shape (len(atracks), len(btracks)).
+
+    Examples:
+        Compute IoU distance between two sets of tracks
+        >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]
+        >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
+        >>> cost_matrix = iou_distance(atracks, btracks)
+    """
+    if (atracks and isinstance(atracks[0], np.ndarray)) or (btracks and isinstance(btracks[0], np.ndarray)):
+        atlbrs = atracks
+        btlbrs = btracks
+    else:
+        atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]
+        btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]
+
+    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
+    if len(atlbrs) and len(btlbrs):
+        if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:
+            ious = batch_probiou(
+                np.ascontiguousarray(atlbrs, dtype=np.float32),
+                np.ascontiguousarray(btlbrs, dtype=np.float32),
+            ).numpy()
+        else:
+            ious = bbox_ioa(
+                np.ascontiguousarray(atlbrs, dtype=np.float32),
+                np.ascontiguousarray(btlbrs, dtype=np.float32),
+                iou=True,
+            )
+    return 1 - ious  # cost matrix
+
+
+def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
+    """Compute distance between tracks and detections based on embeddings.
+
+    Args:
+        tracks (list[STrack]): List of tracks, where each track contains embedding features.
+        detections (list[BaseTrack]): List of detections, where each detection contains embedding features.
+        metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.
+
+    Returns:
+        (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks and M
+            is the number of detections.
+
+    Examples:
+        Compute the embedding distance between tracks and detections using cosine metric
+        >>> tracks = [STrack(...), STrack(...)]  # List of track objects with embedding features
+        >>> detections = [BaseTrack(...), BaseTrack(...)]  # List of detection objects with embedding features
+        >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine")
+    """
+    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
+    if cost_matrix.size == 0:
+        return cost_matrix
+    det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
+    # for i, track in enumerate(tracks):
+    #     cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
+    track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
+    cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric))  # Normalized features
+    return cost_matrix
+
+
+def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
+    """Fuse cost matrix with detection scores to produce a single similarity matrix.
+
+    Args:
+        cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
+        detections (list[BaseTrack]): List of detections, each containing a score attribute.
+
+    Returns:
+        (np.ndarray): Fused similarity matrix with shape (N, M).
+
+    Examples:
+        Fuse a cost matrix with detection scores
+        >>> cost_matrix = np.random.rand(5, 10)  # 5 tracks and 10 detections
+        >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]
+        >>> fused_matrix = fuse_score(cost_matrix, detections)
+    """
+    if cost_matrix.size == 0:
+        return cost_matrix
+    iou_sim = 1 - cost_matrix
+    det_scores = np.array([det.score for det in detections])
+    det_scores = det_scores[None].repeat(cost_matrix.shape[0], axis=0)
+    fuse_sim = iou_sim * det_scores
+    return 1 - fuse_sim  # fuse_cost
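And a small end-to-end association sketch for the helpers above (illustrative only: the boxes and scores are invented, _Det is a hypothetical stand-in exposing the score attribute that fuse_score reads, and use_lap=False selects the SciPy solver so no extra dependency is pulled in):

import numpy as np

from ultralytics.trackers.utils.matching import fuse_score, iou_distance, linear_assignment

# Invented track and detection boxes in (x1, y1, x2, y2) format.
track_boxes = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]
det_boxes = [np.array([1, 1, 11, 11]), np.array([50, 50, 60, 60])]

cost = iou_distance(track_boxes, det_boxes)  # (2, 2) matrix of 1 - IoU


class _Det:  # hypothetical stand-in with the .score attribute fuse_score expects
    def __init__(self, score):
        self.score = score


fused_cost = fuse_score(cost, [_Det(0.9), _Det(0.4)])  # down-weights low-confidence detections

# Assign tracks to detections; pairs whose cost exceeds `thresh` stay unmatched.
matches, unmatched_tracks, unmatched_dets = linear_assignment(fused_cost, thresh=0.8, use_lap=False)
print(matches, unmatched_tracks, unmatched_dets)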