ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +527 -67
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +44 -37
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +84 -56
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.28.dist-info/METADATA +0 -373
  244. ultralytics-8.1.28.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
--- a/ultralytics/trackers/byte_tracker.py
+++ b/ultralytics/trackers/byte_tracker.py
@@ -1,12 +1,12 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import numpy as np

+from ..utils import LOGGER
+from ..utils.ops import xywh2ltwh
 from .basetrack import BaseTrack, TrackState
 from .utils import matching
 from .utils.kalman_filter import KalmanFilterXYAH
-from ..utils.ops import xywh2ltwh
-from ..utils import LOGGER


 class STrack(BaseTrack):
@@ -25,7 +25,7 @@ class STrack(BaseTrack):
         is_activated (bool): Boolean flag indicating if the track has been activated.
         score (float): Confidence score of the track.
         tracklet_len (int): Length of the tracklet.
-        cls (any): Class label for the object.
+        cls (Any): Class label for the object.
         idx (int): Index or identifier for the object.
         frame_id (int): Current frame ID.
         start_frame (int): Frame where the object was first detected.
@@ -39,15 +39,34 @@ class STrack(BaseTrack):
         update(new_track, frame_id): Update the state of a matched track.
         convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
         tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
+
+    Examples:
+        Initialize and activate a new track
+        >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls="person")
+        >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)
     """

     shared_kalman = KalmanFilterXYAH()

     def __init__(self, xywh, score, cls):
-        """Initialize new STrack instance."""
+        """
+        Initialize a new STrack instance.
+
+        Args:
+            xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
+                (x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
+            score (float): Confidence score of the detection.
+            cls (Any): Class label for the detected object.
+
+        Examples:
+            >>> xywh = [100.0, 150.0, 50.0, 75.0, 1]
+            >>> score = 0.9
+            >>> cls = "person"
+            >>> track = STrack(xywh, score, cls)
+        """
         super().__init__()
         # xywh+idx or xywha+idx
-        assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}"
+        assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
         self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
         self.kalman_filter = None
         self.mean, self.covariance = None, None
@@ -60,7 +79,7 @@ class STrack(BaseTrack):
         self.angle = xywh[4] if len(xywh) == 6 else None

     def predict(self):
-        """Predicts mean and covariance using Kalman filter."""
+        """Predicts the next state (mean and covariance) of the object using the Kalman filter."""
         mean_state = self.mean.copy()
         if self.state != TrackState.Tracked:
             mean_state[7] = 0
@@ -68,7 +87,7 @@ class STrack(BaseTrack):

     @staticmethod
     def multi_predict(stracks):
-        """Perform multi-object predictive tracking using Kalman filter for given stracks."""
+        """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
         if len(stracks) <= 0:
             return
         multi_mean = np.asarray([st.mean.copy() for st in stracks])
@@ -83,7 +102,7 @@ class STrack(BaseTrack):

     @staticmethod
     def multi_gmc(stracks, H=np.eye(2, 3)):
-        """Update state tracks positions and covariances using a homography matrix."""
+        """Update state tracks positions and covariances using a homography matrix for multiple tracks."""
         if len(stracks) > 0:
             multi_mean = np.asarray([st.mean.copy() for st in stracks])
             multi_covariance = np.asarray([st.covariance for st in stracks])
@@ -101,7 +120,7 @@ class STrack(BaseTrack):
             stracks[i].covariance = cov

     def activate(self, kalman_filter, frame_id):
-        """Start a new tracklet."""
+        """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
         self.kalman_filter = kalman_filter
         self.track_id = self.next_id()
         self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
@@ -114,7 +133,7 @@ class STrack(BaseTrack):
         self.start_frame = frame_id

     def re_activate(self, new_track, frame_id, new_id=False):
-        """Reactivates a previously lost track with a new detection."""
+        """Reactivates a previously lost track using new detection data and updates its state and attributes."""
         self.mean, self.covariance = self.kalman_filter.update(
             self.mean, self.covariance, self.convert_coords(new_track.tlwh)
         )
@@ -136,6 +155,12 @@ class STrack(BaseTrack):
         Args:
             new_track (STrack): The new track containing updated information.
             frame_id (int): The ID of the current frame.
+
+        Examples:
+            Update the state of a track with new detection information
+            >>> track = STrack([100, 200, 50, 80, 0.9, 1])
+            >>> new_track = STrack([105, 205, 55, 85, 0.95, 1])
+            >>> track.update(new_track, 2)
         """
         self.frame_id = frame_id
         self.tracklet_len += 1
@@ -158,7 +183,7 @@ class STrack(BaseTrack):

     @property
     def tlwh(self):
-        """Get current position in bounding box format (top left x, top left y, width, height)."""
+        """Returns the bounding box in top-left-width-height format from the current state estimate."""
         if self.mean is None:
             return self._tlwh.copy()
         ret = self.mean[:4].copy()
@@ -168,16 +193,14 @@ class STrack(BaseTrack):

     @property
     def xyxy(self):
-        """Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
+        """Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
         ret = self.tlwh.copy()
         ret[2:] += ret[:2]
         return ret

     @staticmethod
     def tlwh_to_xyah(tlwh):
-        """Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width /
-        height.
-        """
+        """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
         ret = np.asarray(tlwh).copy()
         ret[:2] += ret[2:] / 2
         ret[2] /= ret[3]
@@ -185,14 +208,14 @@ class STrack(BaseTrack):

     @property
     def xywh(self):
-        """Get current position in bounding box format (center x, center y, width, height)."""
+        """Returns the current position of the bounding box in (center x, center y, width, height) format."""
         ret = np.asarray(self.tlwh).copy()
         ret[:2] += ret[2:] / 2
         return ret

     @property
     def xywha(self):
-        """Get current position in bounding box format (center x, center y, width, height, angle)."""
+        """Returns position in (center x, center y, width, height, angle) format, warning if angle is missing."""
         if self.angle is None:
             LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
             return self.xywh
@@ -200,12 +223,12 @@ class STrack(BaseTrack):

     @property
     def result(self):
-        """Get current tracking results."""
+        """Returns the current tracking results in the appropriate bounding box format."""
         coords = self.xyxy if self.angle is None else self.xywha
         return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]

     def __repr__(self):
-        """Return a string representation of the BYTETracker object with start and end frames and track ID."""
+        """Returns a string representation of the STrack object including start frame, end frame, and track ID."""
         return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"


@@ -213,18 +236,18 @@ class BYTETracker:
     """
     BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.

-    The class is responsible for initializing, updating, and managing the tracks for detected objects in a video
-    sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
-    predicting the new object locations, and performs data association.
+    Responsible for initializing, updating, and managing the tracks for detected objects in a video sequence.
+    It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for predicting
+    the new object locations, and performs data association.

     Attributes:
-        tracked_stracks (list[STrack]): List of successfully activated tracks.
-        lost_stracks (list[STrack]): List of lost tracks.
-        removed_stracks (list[STrack]): List of removed tracks.
+        tracked_stracks (List[STrack]): List of successfully activated tracks.
+        lost_stracks (List[STrack]): List of lost tracks.
+        removed_stracks (List[STrack]): List of removed tracks.
         frame_id (int): The current frame ID.
-        args (namespace): Command-line arguments.
+        args (Namespace): Command-line arguments.
         max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
-        kalman_filter (object): Kalman Filter object.
+        kalman_filter (KalmanFilterXYAH): Kalman Filter object.

     Methods:
         update(results, img=None): Updates object tracker with new detections.
@@ -236,10 +259,27 @@ class BYTETracker:
         joint_stracks(tlista, tlistb): Combines two lists of stracks.
         sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
         remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
+
+    Examples:
+        Initialize BYTETracker and update with detection results
+        >>> tracker = BYTETracker(args, frame_rate=30)
+        >>> results = yolo_model.detect(image)
+        >>> tracked_objects = tracker.update(results)
     """

     def __init__(self, args, frame_rate=30):
-        """Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
+        """
+        Initialize a BYTETracker instance for object tracking.
+
+        Args:
+            args (Namespace): Command-line arguments containing tracking parameters.
+            frame_rate (int): Frame rate of the video sequence.
+
+        Examples:
+            Initialize BYTETracker with command-line arguments and a frame rate of 30
+            >>> args = Namespace(track_buffer=30)
+            >>> tracker = BYTETracker(args, frame_rate=30)
+        """
         self.tracked_stracks = []  # type: list[STrack]
         self.lost_stracks = []  # type: list[STrack]
         self.removed_stracks = []  # type: list[STrack]
@@ -251,7 +291,7 @@ class BYTETracker:
         self.reset_id()

     def update(self, results, img=None):
-        """Updates object tracker with new detections and returns tracked object bounding boxes."""
+        """Updates the tracker with new detections and returns the current list of tracked objects."""
         self.frame_id += 1
         activated_stracks = []
         refind_stracks = []
@@ -264,11 +304,11 @@ class BYTETracker:
         bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
         cls = results.cls

-        remain_inds = scores > self.args.track_high_thresh
+        remain_inds = scores >= self.args.track_high_thresh
         inds_low = scores > self.args.track_low_thresh
         inds_high = scores < self.args.track_high_thresh

-        inds_second = np.logical_and(inds_low, inds_high)
+        inds_second = inds_low & inds_high
         dets_second = bboxes[inds_second]
         dets = bboxes[remain_inds]
         scores_keep = scores[remain_inds]
@@ -365,32 +405,31 @@ class BYTETracker:
         return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)

     def get_kalmanfilter(self):
-        """Returns a Kalman filter object for tracking bounding boxes."""
+        """Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
         return KalmanFilterXYAH()

     def init_track(self, dets, scores, cls, img=None):
-        """Initialize object tracking with detections and scores using STrack algorithm."""
+        """Initializes object tracking with given detections, scores, and class labels using the STrack algorithm."""
         return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else []  # detections

     def get_dists(self, tracks, detections):
-        """Calculates the distance between tracks and detections using IoU and fuses scores."""
+        """Calculates the distance between tracks and detections using IoU and optionally fuses scores."""
         dists = matching.iou_distance(tracks, detections)
-        # TODO: mot20
-        # if not self.args.mot20:
-        dists = matching.fuse_score(dists, detections)
+        if self.args.fuse_score:
+            dists = matching.fuse_score(dists, detections)
         return dists

     def multi_predict(self, tracks):
-        """Returns the predicted tracks using the YOLOv8 network."""
+        """Predict the next states for multiple tracks using Kalman filter."""
         STrack.multi_predict(tracks)

     @staticmethod
     def reset_id():
-        """Resets the ID counter of STrack."""
+        """Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
         STrack.reset_id()

     def reset(self):
-        """Reset tracker."""
+        """Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
         self.tracked_stracks = []  # type: list[STrack]
         self.lost_stracks = []  # type: list[STrack]
         self.removed_stracks = []  # type: list[STrack]
@@ -400,7 +439,7 @@ class BYTETracker:

     @staticmethod
     def joint_stracks(tlista, tlistb):
-        """Combine two lists of stracks into a single one."""
+        """Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
         exists = {}
         res = []
         for t in tlista:
@@ -415,20 +454,13 @@ class BYTETracker:

     @staticmethod
     def sub_stracks(tlista, tlistb):
-        """DEPRECATED CODE in https://github.com/ultralytics/ultralytics/pull/1890/
-        stracks = {t.track_id: t for t in tlista}
-        for t in tlistb:
-            tid = t.track_id
-            if stracks.get(tid, 0):
-                del stracks[tid]
-        return list(stracks.values())
-        """
+        """Filters out the stracks present in the second list from the first list."""
         track_ids_b = {t.track_id for t in tlistb}
         return [t for t in tlista if t.track_id not in track_ids_b]

     @staticmethod
     def remove_duplicate_stracks(stracksa, stracksb):
-        """Remove duplicate stracks with non-maximum IoU distance."""
+        """Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
         pdist = matching.iou_distance(stracksa, stracksb)
         pairs = np.where(pdist < 0.15)
         dupa, dupb = [], []
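The update() hunk above changes the first-stage gate on track_high_thresh from a strict `>` to an inclusive `>=` and swaps np.logical_and for the `&` operator. A minimal NumPy sketch of that two-stage score split follows; the scores and threshold values are made up for illustration, not read from the package's bytetrack.yaml:

import numpy as np

# Illustrative detection confidences and thresholds (assumed values, not the shipped defaults).
scores = np.array([0.92, 0.55, 0.30, 0.12])
track_high_thresh, track_low_thresh = 0.5, 0.25

remain_inds = scores >= track_high_thresh                                  # first association (now inclusive)
inds_second = (scores > track_low_thresh) & (scores < track_high_thresh)  # second, low-confidence association

print(remain_inds)   # [ True  True False False]
print(inds_second)   # [False False  True False]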
--- a/ultralytics/trackers/track.py
+++ b/ultralytics/trackers/track.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from functools import partial
 from pathlib import Path
@@ -7,6 +7,7 @@ import torch

 from ultralytics.utils import IterableSimpleNamespace, yaml_load
 from ultralytics.utils.checks import check_yaml
+
 from .bot_sort import BOTSORT
 from .byte_tracker import BYTETracker

@@ -20,10 +21,15 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:

     Args:
         predictor (object): The predictor object to initialize trackers for.
-        persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
+        persist (bool): Whether to persist the trackers if they already exist.

     Raises:
         AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
+
+    Examples:
+        Initialize trackers for a predictor object:
+        >>> predictor = SomePredictorClass()
+        >>> on_predict_start(predictor, persist=True)
     """
     if hasattr(predictor, "trackers") and persist:
         return
@@ -31,7 +37,7 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
     tracker = check_yaml(predictor.args.tracker)
     cfg = IterableSimpleNamespace(**yaml_load(tracker))

-    if cfg.tracker_type not in ["bytetrack", "botsort"]:
+    if cfg.tracker_type not in {"bytetrack", "botsort"}:
         raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")

     trackers = []
@@ -50,7 +56,12 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None

     Args:
         predictor (object): The predictor object containing the predictions.
-        persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
+        persist (bool): Whether to persist the trackers if they already exist.
+
+    Examples:
+        Postprocess predictions and update with tracking
+        >>> predictor = YourPredictorClass()
+        >>> on_predict_postprocess_end(predictor, persist=True)
     """
     path, im0s = predictor.batch[:2]

@@ -72,8 +83,7 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
         idx = tracks[:, -1].astype(int)
         predictor.results[i] = predictor.results[i][idx]

-        update_args = dict()
-        update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1])
+        update_args = {"obb" if is_obb else "boxes": torch.as_tensor(tracks[:, :-1])}
         predictor.results[i].update(**update_args)


@@ -84,6 +94,11 @@ def register_tracker(model: object, persist: bool) -> None:
     Args:
         model (object): The model object to register tracking callbacks for.
         persist (bool): Whether to persist the trackers if they already exist.
+
+    Examples:
+        Register tracking callbacks to a YOLO model
+        >>> model = YOLOModel()
+        >>> register_tracker(model, persist=True)
     """
     model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
     model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))
@@ -1 +1 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
--- a/ultralytics/trackers/utils/gmc.py
+++ b/ultralytics/trackers/utils/gmc.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import copy

@@ -19,32 +19,44 @@ class GMC:
         method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
         downscale (int): Factor by which to downscale the frames for processing.
         prevFrame (np.ndarray): Stores the previous frame for tracking.
-        prevKeyPoints (list): Stores the keypoints from the previous frame.
+        prevKeyPoints (List): Stores the keypoints from the previous frame.
         prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
         initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.

     Methods:
-        __init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method
-            and downscale factor.
-        apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses
-            provided detections.
-        applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame.
-        applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame.
-        applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
+        __init__: Initializes a GMC object with the specified method and downscale factor.
+        apply: Applies the chosen method to a raw frame and optionally uses provided detections.
+        apply_ecc: Applies the ECC algorithm to a raw frame.
+        apply_features: Applies feature-based methods like ORB or SIFT to a raw frame.
+        apply_sparseoptflow: Applies the Sparse Optical Flow method to a raw frame.
+        reset_params: Resets the internal parameters of the GMC object.
+
+    Examples:
+        Create a GMC object and apply it to a frame
+        >>> gmc = GMC(method="sparseOptFlow", downscale=2)
+        >>> frame = np.array([[1, 2, 3], [4, 5, 6]])
+        >>> processed_frame = gmc.apply(frame)
+        >>> print(processed_frame)
+        array([[1, 2, 3],
+               [4, 5, 6]])
     """

     def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
         """
-        Initialize a video tracker with specified parameters.
+        Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.

         Args:
             method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
             downscale (int): Downscale factor for processing frames.
+
+        Examples:
+            Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2
+            >>> gmc = GMC(method="sparseOptFlow", downscale=2)
         """
         super().__init__()

         self.method = method
-        self.downscale = max(1, int(downscale))
+        self.downscale = max(1, downscale)

         if self.method == "orb":
             self.detector = cv2.FastFeatureDetector_create(20)
@@ -79,45 +91,47 @@ class GMC:

     def apply(self, raw_frame: np.array, detections: list = None) -> np.array:
         """
-        Apply object detection on a raw frame using specified method.
+        Apply object detection on a raw frame using the specified method.

         Args:
-            raw_frame (np.ndarray): The raw frame to be processed.
-            detections (list): List of detections to be used in the processing.
+            raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
+            detections (List | None): List of detections to be used in the processing.

         Returns:
-            (np.ndarray): Processed frame.
+            (np.ndarray): Processed frame with applied object detection.

         Examples:
-            >>> gmc = GMC()
-            >>> gmc.apply(np.array([[1, 2, 3], [4, 5, 6]]))
-            array([[1, 2, 3],
-                   [4, 5, 6]])
+            >>> gmc = GMC(method="sparseOptFlow")
+            >>> raw_frame = np.random.rand(480, 640, 3)
+            >>> processed_frame = gmc.apply(raw_frame)
+            >>> print(processed_frame.shape)
+            (480, 640, 3)
         """
-        if self.method in ["orb", "sift"]:
-            return self.applyFeatures(raw_frame, detections)
+        if self.method in {"orb", "sift"}:
+            return self.apply_features(raw_frame, detections)
         elif self.method == "ecc":
-            return self.applyEcc(raw_frame)
+            return self.apply_ecc(raw_frame)
         elif self.method == "sparseOptFlow":
-            return self.applySparseOptFlow(raw_frame)
+            return self.apply_sparseoptflow(raw_frame)
         else:
             return np.eye(2, 3)

-    def applyEcc(self, raw_frame: np.array) -> np.array:
+    def apply_ecc(self, raw_frame: np.array) -> np.array:
         """
-        Apply ECC algorithm to a raw frame.
+        Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.

         Args:
-            raw_frame (np.ndarray): The raw frame to be processed.
+            raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).

         Returns:
-            (np.ndarray): Processed frame.
+            (np.ndarray): The processed frame with the applied ECC transformation.

         Examples:
-            >>> gmc = GMC()
-            >>> gmc.applyEcc(np.array([[1, 2, 3], [4, 5, 6]]))
-            array([[1, 2, 3],
-                   [4, 5, 6]])
+            >>> gmc = GMC(method="ecc")
+            >>> processed_frame = gmc.apply_ecc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
+            >>> print(processed_frame)
+            [[1. 0. 0.]
+             [0. 1. 0.]]
         """
         height, width, _ = raw_frame.shape
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
@@ -127,8 +141,6 @@ class GMC:
         if self.downscale > 1.0:
             frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
             frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
-            width = width // self.downscale
-            height = height // self.downscale

         # Handle first frame
         if not self.initializedFirstFrame:
@@ -149,22 +161,23 @@ class GMC:

         return H

-    def applyFeatures(self, raw_frame: np.array, detections: list = None) -> np.array:
+    def apply_features(self, raw_frame: np.array, detections: list = None) -> np.array:
         """
         Apply feature-based methods like ORB or SIFT to a raw frame.

         Args:
-            raw_frame (np.ndarray): The raw frame to be processed.
-            detections (list): List of detections to be used in the processing.
+            raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
+            detections (List | None): List of detections to be used in the processing.

         Returns:
             (np.ndarray): Processed frame.

         Examples:
-            >>> gmc = GMC()
-            >>> gmc.applyFeatures(np.array([[1, 2, 3], [4, 5, 6]]))
-            array([[1, 2, 3],
-                   [4, 5, 6]])
+            >>> gmc = GMC(method="orb")
+            >>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
+            >>> processed_frame = gmc.apply_features(raw_frame)
+            >>> print(processed_frame.shape)
+            (2, 3)
         """
         height, width, _ = raw_frame.shape
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
@@ -291,21 +304,22 @@ class GMC:

         return H

-    def applySparseOptFlow(self, raw_frame: np.array) -> np.array:
+    def apply_sparseoptflow(self, raw_frame: np.array) -> np.array:
         """
         Apply Sparse Optical Flow method to a raw frame.

         Args:
-            raw_frame (np.ndarray): The raw frame to be processed.
+            raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).

         Returns:
-            (np.ndarray): Processed frame.
+            (np.ndarray): Processed frame with shape (2, 3).

         Examples:
             >>> gmc = GMC()
-            >>> gmc.applySparseOptFlow(np.array([[1, 2, 3], [4, 5, 6]]))
-            array([[1, 2, 3],
-                   [4, 5, 6]])
+            >>> result = gmc.apply_sparseoptflow(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
+            >>> print(result)
+            [[1. 0. 0.]
+             [0. 1. 0.]]
         """
         height, width, _ = raw_frame.shape
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
@@ -319,7 +333,7 @@ class GMC:
         keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)

         # Handle first frame
-        if not self.initializedFirstFrame:
+        if not self.initializedFirstFrame or self.prevKeyPoints is None:
             self.prevFrame = frame.copy()
             self.prevKeyPoints = copy.copy(keypoints)
             self.initializedFirstFrame = True
@@ -356,7 +370,7 @@ class GMC:
         return H

     def reset_params(self) -> None:
-        """Reset parameters."""
+        """Reset the internal parameters including previous frame, keypoints, and descriptors."""
         self.prevFrame = None
         self.prevKeyPoints = None
         self.prevDescriptors = None
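Besides the docstring updates, this diff renames applyEcc, applyFeatures, and applySparseOptFlow to apply_ecc, apply_features, and apply_sparseoptflow, so downstream code that called the old camelCase names must switch to the snake_case ones. A usage sketch of the renamed API, modeled on the docstring examples above (assumes numpy and opencv-python are installed; the random frame is a placeholder, not real video data):

import numpy as np
from ultralytics.trackers.utils.gmc import GMC

gmc = GMC(method="sparseOptFlow", downscale=2)
frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # placeholder BGR frame
H = gmc.apply(frame)  # dispatches to apply_sparseoptflow(); the first frame returns an identity 2x3 warp
print(H.shape)        # (2, 3)
gmc.reset_params()    # clears prevFrame, prevKeyPoints, and prevDescriptors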