dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,9 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
3
7
  import numpy as np
4
8
 
5
9
  from ..utils import LOGGER
@@ -10,8 +14,7 @@ from .utils.kalman_filter import KalmanFilterXYAH
10
14
 
11
15
 
12
16
  class STrack(BaseTrack):
13
- """
14
- Single object tracking representation that uses Kalman filtering for state estimation.
17
+ """Single object tracking representation that uses Kalman filtering for state estimation.
15
18
 
16
19
  This class is responsible for storing all the information regarding individual tracklets and performs state updates
17
20
  and predictions based on Kalman filter.
@@ -29,16 +32,17 @@ class STrack(BaseTrack):
29
32
  idx (int): Index or identifier for the object.
30
33
  frame_id (int): Current frame ID.
31
34
  start_frame (int): Frame where the object was first detected.
35
+ angle (float | None): Optional angle information for oriented bounding boxes.
32
36
 
33
37
  Methods:
34
- predict(): Predict the next state of the object using Kalman filter.
35
- multi_predict(stracks): Predict the next states for multiple tracks.
36
- multi_gmc(stracks, H): Update multiple track states using a homography matrix.
37
- activate(kalman_filter, frame_id): Activate a new tracklet.
38
- re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
39
- update(new_track, frame_id): Update the state of a matched track.
40
- convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
41
- tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
38
+ predict: Predict the next state of the object using Kalman filter.
39
+ multi_predict: Predict the next states for multiple tracks.
40
+ multi_gmc: Update multiple track states using a homography matrix.
41
+ activate: Activate a new tracklet.
42
+ re_activate: Reactivate a previously lost tracklet.
43
+ update: Update the state of a matched track.
44
+ convert_coords: Convert bounding box to x-y-aspect-height format.
45
+ tlwh_to_xyah: Convert tlwh bounding box to xyah format.
42
46
 
43
47
  Examples:
44
48
  Initialize and activate a new track
@@ -48,13 +52,12 @@ class STrack(BaseTrack):
48
52
 
49
53
  shared_kalman = KalmanFilterXYAH()
50
54
 
51
- def __init__(self, xywh, score, cls):
52
- """
53
- Initialize a new STrack instance.
55
+ def __init__(self, xywh: list[float], score: float, cls: Any):
56
+ """Initialize a new STrack instance.
54
57
 
55
58
  Args:
56
- xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
57
- (x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
59
+ xywh (list[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where (x,
60
+ y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
58
61
  score (float): Confidence score of the detection.
59
62
  cls (Any): Class label for the detected object.
60
63
 
@@ -79,14 +82,14 @@ class STrack(BaseTrack):
79
82
  self.angle = xywh[4] if len(xywh) == 6 else None
80
83
 
81
84
  def predict(self):
82
- """Predicts the next state (mean and covariance) of the object using the Kalman filter."""
85
+ """Predict the next state (mean and covariance) of the object using the Kalman filter."""
83
86
  mean_state = self.mean.copy()
84
87
  if self.state != TrackState.Tracked:
85
88
  mean_state[7] = 0
86
89
  self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
87
90
 
88
91
  @staticmethod
89
- def multi_predict(stracks):
92
+ def multi_predict(stracks: list[STrack]):
90
93
  """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
91
94
  if len(stracks) <= 0:
92
95
  return
@@ -101,9 +104,9 @@ class STrack(BaseTrack):
101
104
  stracks[i].covariance = cov
102
105
 
103
106
  @staticmethod
104
- def multi_gmc(stracks, H=np.eye(2, 3)):
107
+ def multi_gmc(stracks: list[STrack], H: np.ndarray = np.eye(2, 3)):
105
108
  """Update state tracks positions and covariances using a homography matrix for multiple tracks."""
106
- if len(stracks) > 0:
109
+ if stracks:
107
110
  multi_mean = np.asarray([st.mean.copy() for st in stracks])
108
111
  multi_covariance = np.asarray([st.covariance for st in stracks])
109
112
 
@@ -119,7 +122,7 @@ class STrack(BaseTrack):
119
122
  stracks[i].mean = mean
120
123
  stracks[i].covariance = cov
121
124
 
122
- def activate(self, kalman_filter, frame_id):
125
+ def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):
123
126
  """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
124
127
  self.kalman_filter = kalman_filter
125
128
  self.track_id = self.next_id()
@@ -132,8 +135,8 @@ class STrack(BaseTrack):
132
135
  self.frame_id = frame_id
133
136
  self.start_frame = frame_id
134
137
 
135
- def re_activate(self, new_track, frame_id, new_id=False):
136
- """Reactivates a previously lost track using new detection data and updates its state and attributes."""
138
+ def re_activate(self, new_track: STrack, frame_id: int, new_id: bool = False):
139
+ """Reactivate a previously lost track using new detection data and update its state and attributes."""
137
140
  self.mean, self.covariance = self.kalman_filter.update(
138
141
  self.mean, self.covariance, self.convert_coords(new_track.tlwh)
139
142
  )
@@ -148,9 +151,8 @@ class STrack(BaseTrack):
148
151
  self.angle = new_track.angle
149
152
  self.idx = new_track.idx
150
153
 
151
- def update(self, new_track, frame_id):
152
- """
153
- Update the state of a matched track.
154
+ def update(self, new_track: STrack, frame_id: int):
155
+ """Update the state of a matched track.
154
156
 
155
157
  Args:
156
158
  new_track (STrack): The new track containing updated information.
@@ -177,13 +179,13 @@ class STrack(BaseTrack):
177
179
  self.angle = new_track.angle
178
180
  self.idx = new_track.idx
179
181
 
180
- def convert_coords(self, tlwh):
182
+ def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
181
183
  """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
182
184
  return self.tlwh_to_xyah(tlwh)
183
185
 
184
186
  @property
185
- def tlwh(self):
186
- """Returns the bounding box in top-left-width-height format from the current state estimate."""
187
+ def tlwh(self) -> np.ndarray:
188
+ """Get the bounding box in top-left-width-height format from the current state estimate."""
187
189
  if self.mean is None:
188
190
  return self._tlwh.copy()
189
191
  ret = self.mean[:4].copy()
@@ -192,14 +194,14 @@ class STrack(BaseTrack):
192
194
  return ret
193
195
 
194
196
  @property
195
- def xyxy(self):
196
- """Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
197
+ def xyxy(self) -> np.ndarray:
198
+ """Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
197
199
  ret = self.tlwh.copy()
198
200
  ret[2:] += ret[:2]
199
201
  return ret
200
202
 
201
203
  @staticmethod
202
- def tlwh_to_xyah(tlwh):
204
+ def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:
203
205
  """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
204
206
  ret = np.asarray(tlwh).copy()
205
207
  ret[:2] += ret[2:] / 2
@@ -207,58 +209,58 @@ class STrack(BaseTrack):
207
209
  return ret
208
210
 
209
211
  @property
210
- def xywh(self):
211
- """Returns the current position of the bounding box in (center x, center y, width, height) format."""
212
+ def xywh(self) -> np.ndarray:
213
+ """Get the current position of the bounding box in (center x, center y, width, height) format."""
212
214
  ret = np.asarray(self.tlwh).copy()
213
215
  ret[:2] += ret[2:] / 2
214
216
  return ret
215
217
 
216
218
  @property
217
- def xywha(self):
218
- """Returns position in (center x, center y, width, height, angle) format, warning if angle is missing."""
219
+ def xywha(self) -> np.ndarray:
220
+ """Get position in (center x, center y, width, height, angle) format, warning if angle is missing."""
219
221
  if self.angle is None:
220
222
  LOGGER.warning("`angle` attr not found, returning `xywh` instead.")
221
223
  return self.xywh
222
224
  return np.concatenate([self.xywh, self.angle[None]])
223
225
 
224
226
  @property
225
- def result(self):
226
- """Returns the current tracking results in the appropriate bounding box format."""
227
+ def result(self) -> list[float]:
228
+ """Get the current tracking results in the appropriate bounding box format."""
227
229
  coords = self.xyxy if self.angle is None else self.xywha
228
- return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
230
+ return [*coords.tolist(), self.track_id, self.score, self.cls, self.idx]
229
231
 
230
- def __repr__(self):
231
- """Returns a string representation of the STrack object including start frame, end frame, and track ID."""
232
+ def __repr__(self) -> str:
233
+ """Return a string representation of the STrack object including start frame, end frame, and track ID."""
232
234
  return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
233
235
 
234
236
 
235
237
  class BYTETracker:
236
- """
237
- BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
238
+ """BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
238
239
 
239
- This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects in a
240
- video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
241
- predicting the new object locations, and performs data association.
240
+ This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects
241
+ in a video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman
242
+ filtering for predicting the new object locations, and performs data association.
242
243
 
243
244
  Attributes:
244
- tracked_stracks (List[STrack]): List of successfully activated tracks.
245
- lost_stracks (List[STrack]): List of lost tracks.
246
- removed_stracks (List[STrack]): List of removed tracks.
245
+ tracked_stracks (list[STrack]): List of successfully activated tracks.
246
+ lost_stracks (list[STrack]): List of lost tracks.
247
+ removed_stracks (list[STrack]): List of removed tracks.
247
248
  frame_id (int): The current frame ID.
248
249
  args (Namespace): Command-line arguments.
249
250
  max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
250
251
  kalman_filter (KalmanFilterXYAH): Kalman Filter object.
251
252
 
252
253
  Methods:
253
- update(results, img=None): Updates object tracker with new detections.
254
- get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes.
255
- init_track(dets, scores, cls, img=None): Initialize object tracking with detections.
256
- get_dists(tracks, detections): Calculates the distance between tracks and detections.
257
- multi_predict(tracks): Predicts the location of tracks.
258
- reset_id(): Resets the ID counter of STrack.
259
- joint_stracks(tlista, tlistb): Combines two lists of stracks.
260
- sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
261
- remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
254
+ update: Update object tracker with new detections.
255
+ get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.
256
+ init_track: Initialize object tracking with detections.
257
+ get_dists: Calculate the distance between tracks and detections.
258
+ multi_predict: Predict the location of tracks.
259
+ reset_id: Reset the ID counter of STrack.
260
+ reset: Reset the tracker by clearing all tracks.
261
+ joint_stracks: Combine two lists of stracks.
262
+ sub_stracks: Filter out the stracks present in the second list from the first list.
263
+ remove_duplicate_stracks: Remove duplicate stracks based on IoU.
262
264
 
263
265
  Examples:
264
266
  Initialize BYTETracker and update with detection results
@@ -267,9 +269,8 @@ class BYTETracker:
267
269
  >>> tracked_objects = tracker.update(results)
268
270
  """
269
271
 
270
- def __init__(self, args, frame_rate=30):
271
- """
272
- Initialize a BYTETracker instance for object tracking.
272
+ def __init__(self, args, frame_rate: int = 30):
273
+ """Initialize a BYTETracker instance for object tracking.
273
274
 
274
275
  Args:
275
276
  args (Namespace): Command-line arguments containing tracking parameters.
@@ -290,8 +291,8 @@ class BYTETracker:
290
291
  self.kalman_filter = self.get_kalmanfilter()
291
292
  self.reset_id()
292
293
 
293
- def update(self, results, img=None, feats=None):
294
- """Updates the tracker with new detections and returns the current list of tracked objects."""
294
+ def update(self, results, img: np.ndarray | None = None, feats: np.ndarray | None = None) -> np.ndarray:
295
+ """Update the tracker with new detections and return the current list of tracked objects."""
295
296
  self.frame_id += 1
296
297
  activated_stracks = []
297
298
  refind_stracks = []
@@ -299,24 +300,19 @@ class BYTETracker:
299
300
  removed_stracks = []
300
301
 
301
302
  scores = results.conf
302
- bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
303
- # Add index
304
- bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
305
- cls = results.cls
306
-
307
303
  remain_inds = scores >= self.args.track_high_thresh
308
304
  inds_low = scores > self.args.track_low_thresh
309
305
  inds_high = scores < self.args.track_high_thresh
310
306
 
311
307
  inds_second = inds_low & inds_high
312
- dets_second = bboxes[inds_second]
313
- dets = bboxes[remain_inds]
314
- scores_keep = scores[remain_inds]
315
- scores_second = scores[inds_second]
316
- cls_keep = cls[remain_inds]
317
- cls_second = cls[inds_second]
318
-
319
- detections = self.init_track(dets, scores_keep, cls_keep, img if feats is None else feats)
308
+ results_second = results[inds_second]
309
+ results = results[remain_inds]
310
+ feats_keep = feats_second = img
311
+ if feats is not None and len(feats):
312
+ feats_keep = feats[remain_inds]
313
+ feats_second = feats[inds_second]
314
+
315
+ detections = self.init_track(results, feats_keep)
320
316
  # Add newly detected tracklets to tracked_stracks
321
317
  unconfirmed = []
322
318
  tracked_stracks = [] # type: list[STrack]
@@ -332,7 +328,7 @@ class BYTETracker:
332
328
  if hasattr(self, "gmc") and img is not None:
333
329
  # use try-except here to bypass errors from gmc module
334
330
  try:
335
- warp = self.gmc.apply(img, dets)
331
+ warp = self.gmc.apply(img, results.xyxy)
336
332
  except Exception:
337
333
  warp = np.eye(2, 3)
338
334
  STrack.multi_gmc(strack_pool, warp)
@@ -351,11 +347,11 @@ class BYTETracker:
351
347
  track.re_activate(det, self.frame_id, new_id=False)
352
348
  refind_stracks.append(track)
353
349
  # Step 3: Second association, with low score detection boxes association the untrack to the low score detections
354
- detections_second = self.init_track(dets_second, scores_second, cls_second, img if feats is None else feats)
350
+ detections_second = self.init_track(results_second, feats_second)
355
351
  r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
356
352
  # TODO
357
353
  dists = matching.iou_distance(r_tracked_stracks, detections_second)
358
- matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
354
+ matches, u_track, _u_detection_second = matching.linear_assignment(dists, thresh=0.5)
359
355
  for itracked, idet in matches:
360
356
  track = r_tracked_stracks[itracked]
361
357
  det = detections_second[idet]
@@ -408,32 +404,36 @@ class BYTETracker:
408
404
 
409
405
  return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
410
406
 
411
- def get_kalmanfilter(self):
412
- """Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
407
+ def get_kalmanfilter(self) -> KalmanFilterXYAH:
408
+ """Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
413
409
  return KalmanFilterXYAH()
414
410
 
415
- def init_track(self, dets, scores, cls, img=None):
416
- """Initializes object tracking with given detections, scores, and class labels using the STrack algorithm."""
417
- return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
411
+ def init_track(self, results, img: np.ndarray | None = None) -> list[STrack]:
412
+ """Initialize object tracking with given detections, scores, and class labels using the STrack algorithm."""
413
+ if len(results) == 0:
414
+ return []
415
+ bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
416
+ bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
417
+ return [STrack(xywh, s, c) for (xywh, s, c) in zip(bboxes, results.conf, results.cls)]
418
418
 
419
- def get_dists(self, tracks, detections):
420
- """Calculates the distance between tracks and detections using IoU and optionally fuses scores."""
419
+ def get_dists(self, tracks: list[STrack], detections: list[STrack]) -> np.ndarray:
420
+ """Calculate the distance between tracks and detections using IoU and optionally fuse scores."""
421
421
  dists = matching.iou_distance(tracks, detections)
422
422
  if self.args.fuse_score:
423
423
  dists = matching.fuse_score(dists, detections)
424
424
  return dists
425
425
 
426
- def multi_predict(self, tracks):
426
+ def multi_predict(self, tracks: list[STrack]):
427
427
  """Predict the next states for multiple tracks using Kalman filter."""
428
428
  STrack.multi_predict(tracks)
429
429
 
430
430
  @staticmethod
431
431
  def reset_id():
432
- """Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
432
+ """Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
433
433
  STrack.reset_id()
434
434
 
435
435
  def reset(self):
436
- """Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
436
+ """Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
437
437
  self.tracked_stracks = [] # type: list[STrack]
438
438
  self.lost_stracks = [] # type: list[STrack]
439
439
  self.removed_stracks = [] # type: list[STrack]
@@ -442,8 +442,8 @@ class BYTETracker:
442
442
  self.reset_id()
443
443
 
444
444
  @staticmethod
445
- def joint_stracks(tlista, tlistb):
446
- """Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
445
+ def joint_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]:
446
+ """Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
447
447
  exists = {}
448
448
  res = []
449
449
  for t in tlista:
@@ -457,14 +457,14 @@ class BYTETracker:
457
457
  return res
458
458
 
459
459
  @staticmethod
460
- def sub_stracks(tlista, tlistb):
461
- """Filters out the stracks present in the second list from the first list."""
460
+ def sub_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]:
461
+ """Filter out the stracks present in the second list from the first list."""
462
462
  track_ids_b = {t.track_id for t in tlistb}
463
463
  return [t for t in tlista if t.track_id not in track_ids_b]
464
464
 
465
465
  @staticmethod
466
- def remove_duplicate_stracks(stracksa, stracksb):
467
- """Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
466
+ def remove_duplicate_stracks(stracksa: list[STrack], stracksb: list[STrack]) -> tuple[list[STrack], list[STrack]]:
467
+ """Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
468
468
  pdist = matching.iou_distance(stracksa, stracksb)
469
469
  pairs = np.where(pdist < 0.15)
470
470
  dupa, dupb = [], []
@@ -16,19 +16,14 @@ TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
16
16
 
17
17
 
18
18
  def on_predict_start(predictor: object, persist: bool = False) -> None:
19
- """
20
- Initialize trackers for object tracking during prediction.
19
+ """Initialize trackers for object tracking during prediction.
21
20
 
22
21
  Args:
23
- predictor (object): The predictor object to initialize trackers for.
24
- persist (bool): Whether to persist the trackers if they already exist.
25
-
26
- Raises:
27
- AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
28
- ValueError: If the task is 'classify' as classification doesn't support tracking.
22
+ predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.
23
+ persist (bool, optional): Whether to persist the trackers if they already exist.
29
24
 
30
25
  Examples:
31
- Initialize trackers for a predictor object:
26
+ Initialize trackers for a predictor object
32
27
  >>> predictor = SomePredictorClass()
33
28
  >>> on_predict_start(predictor, persist=True)
34
29
  """
@@ -74,12 +69,11 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
74
69
 
75
70
 
76
71
  def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
77
- """
78
- Postprocess detected boxes and update with object tracking.
72
+ """Postprocess detected boxes and update with object tracking.
79
73
 
80
74
  Args:
81
75
  predictor (object): The predictor object containing the predictions.
82
- persist (bool): Whether to persist the trackers if they already exist.
76
+ persist (bool, optional): Whether to persist the trackers if they already exist.
83
77
 
84
78
  Examples:
85
79
  Postprocess predictions and update with tracking
@@ -96,8 +90,6 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
96
90
  predictor.vid_path[i if is_stream else 0] = vid_path
97
91
 
98
92
  det = (result.obb if is_obb else result.boxes).cpu().numpy()
99
- if len(det) == 0:
100
- continue
101
93
  tracks = tracker.update(det, result.orig_img, getattr(result, "feats", None))
102
94
  if len(tracks) == 0:
103
95
  continue
@@ -109,8 +101,7 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
109
101
 
110
102
 
111
103
  def register_tracker(model: object, persist: bool) -> None:
112
- """
113
- Register tracking callbacks to the model for object tracking during prediction.
104
+ """Register tracking callbacks to the model for object tracking during prediction.
114
105
 
115
106
  Args:
116
107
  model (object): The model object to register tracking callbacks for.