dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -2,13 +2,13 @@
2
2
  """Module defines the base classes and structures for object tracking in YOLO."""
3
3
 
4
4
  from collections import OrderedDict
5
+ from typing import Any
5
6
 
6
7
  import numpy as np
7
8
 
8
9
 
9
10
  class TrackState:
10
- """
11
- Enumeration class representing the possible states of an object being tracked.
11
+ """Enumeration class representing the possible states of an object being tracked.
12
12
 
13
13
  Attributes:
14
14
  New (int): State when the object is newly detected.
@@ -29,8 +29,7 @@ class TrackState:
29
29
 
30
30
 
31
31
  class BaseTrack:
32
- """
33
- Base class for object tracking, providing foundational attributes and methods.
32
+ """Base class for object tracking, providing foundational attributes and methods.
34
33
 
35
34
  Attributes:
36
35
  _count (int): Class-level counter for unique track IDs.
@@ -66,15 +65,7 @@ class BaseTrack:
66
65
  _count = 0
67
66
 
68
67
  def __init__(self):
69
- """
70
- Initialize a new track with a unique ID and foundational tracking attributes.
71
-
72
- Examples:
73
- Initialize a new track
74
- >>> track = BaseTrack()
75
- >>> print(track.track_id)
76
- 0
77
- """
68
+ """Initialize a new track with a unique ID and foundational tracking attributes."""
78
69
  self.track_id = 0
79
70
  self.is_activated = False
80
71
  self.state = TrackState.New
@@ -88,37 +79,37 @@ class BaseTrack:
88
79
  self.location = (np.inf, np.inf)
89
80
 
90
81
  @property
91
- def end_frame(self):
92
- """Returns the ID of the most recent frame where the object was tracked."""
82
+ def end_frame(self) -> int:
83
+ """Return the ID of the most recent frame where the object was tracked."""
93
84
  return self.frame_id
94
85
 
95
86
  @staticmethod
96
- def next_id():
87
+ def next_id() -> int:
97
88
  """Increment and return the next unique global track ID for object tracking."""
98
89
  BaseTrack._count += 1
99
90
  return BaseTrack._count
100
91
 
101
- def activate(self, *args):
102
- """Activates the track with provided arguments, initializing necessary attributes for tracking."""
92
+ def activate(self, *args: Any) -> None:
93
+ """Activate the track with provided arguments, initializing necessary attributes for tracking."""
103
94
  raise NotImplementedError
104
95
 
105
- def predict(self):
106
- """Predicts the next state of the track based on the current state and tracking model."""
96
+ def predict(self) -> None:
97
+ """Predict the next state of the track based on the current state and tracking model."""
107
98
  raise NotImplementedError
108
99
 
109
- def update(self, *args, **kwargs):
110
- """Updates the track with new observations and data, modifying its state and attributes accordingly."""
100
+ def update(self, *args: Any, **kwargs: Any) -> None:
101
+ """Update the track with new observations and data, modifying its state and attributes accordingly."""
111
102
  raise NotImplementedError
112
103
 
113
- def mark_lost(self):
114
- """Marks the track as lost by updating its state to TrackState.Lost."""
104
+ def mark_lost(self) -> None:
105
+ """Mark the track as lost by updating its state to TrackState.Lost."""
115
106
  self.state = TrackState.Lost
116
107
 
117
- def mark_removed(self):
118
- """Marks the track as removed by setting its state to TrackState.Removed."""
108
+ def mark_removed(self) -> None:
109
+ """Mark the track as removed by setting its state to TrackState.Removed."""
119
110
  self.state = TrackState.Removed
120
111
 
121
112
  @staticmethod
122
- def reset_id():
113
+ def reset_id() -> None:
123
114
  """Reset the global track ID counter to its initial value."""
124
115
  BaseTrack._count = 0
@@ -1,6 +1,9 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  from collections import deque
6
+ from typing import Any
4
7
 
5
8
  import numpy as np
6
9
  import torch
@@ -16,8 +19,7 @@ from .utils.kalman_filter import KalmanFilterXYWH
16
19
 
17
20
 
18
21
  class BOTrack(STrack):
19
- """
20
- An extended version of the STrack class for YOLO, adding object tracking features.
22
+ """An extended version of the STrack class for YOLO, adding object tracking features.
21
23
 
22
24
  This class extends the STrack class to include additional functionalities for object tracking, such as feature
23
25
  smoothing, Kalman filter prediction, and reactivation of tracks.
@@ -51,26 +53,27 @@ class BOTrack(STrack):
51
53
 
52
54
  shared_kalman = KalmanFilterXYWH()
53
55
 
54
- def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
55
- """
56
- Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
56
+ def __init__(
57
+ self, xywh: np.ndarray, score: float, cls: int, feat: np.ndarray | None = None, feat_history: int = 50
58
+ ):
59
+ """Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
57
60
 
58
61
  Args:
59
- tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).
62
+ xywh (np.ndarray): Bounding box coordinates in xywh format (center x, center y, width, height).
60
63
  score (float): Confidence score of the detection.
61
64
  cls (int): Class ID of the detected object.
62
- feat (np.ndarray | None): Feature vector associated with the detection.
65
+ feat (np.ndarray, optional): Feature vector associated with the detection.
63
66
  feat_history (int): Maximum length of the feature history deque.
64
67
 
65
68
  Examples:
66
69
  Initialize a BOTrack object with bounding box, score, class ID, and feature vector
67
- >>> tlwh = np.array([100, 50, 80, 120])
70
+ >>> xywh = np.array([100, 150, 60, 50])
68
71
  >>> score = 0.9
69
72
  >>> cls = 1
70
73
  >>> feat = np.random.rand(128)
71
- >>> bo_track = BOTrack(tlwh, score, cls, feat)
74
+ >>> bo_track = BOTrack(xywh, score, cls, feat)
72
75
  """
73
- super().__init__(tlwh, score, cls)
76
+ super().__init__(xywh, score, cls)
74
77
 
75
78
  self.smooth_feat = None
76
79
  self.curr_feat = None
@@ -79,7 +82,7 @@ class BOTrack(STrack):
79
82
  self.features = deque([], maxlen=feat_history)
80
83
  self.alpha = 0.9
81
84
 
82
- def update_features(self, feat):
85
+ def update_features(self, feat: np.ndarray) -> None:
83
86
  """Update the feature vector and apply exponential moving average smoothing."""
84
87
  feat /= np.linalg.norm(feat)
85
88
  self.curr_feat = feat
@@ -90,7 +93,7 @@ class BOTrack(STrack):
90
93
  self.features.append(feat)
91
94
  self.smooth_feat /= np.linalg.norm(self.smooth_feat)
92
95
 
93
- def predict(self):
96
+ def predict(self) -> None:
94
97
  """Predict the object's future state using the Kalman filter to update its mean and covariance."""
95
98
  mean_state = self.mean.copy()
96
99
  if self.state != TrackState.Tracked:
@@ -99,20 +102,20 @@ class BOTrack(STrack):
99
102
 
100
103
  self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
101
104
 
102
- def re_activate(self, new_track, frame_id, new_id=False):
105
+ def re_activate(self, new_track: BOTrack, frame_id: int, new_id: bool = False) -> None:
103
106
  """Reactivate a track with updated features and optionally assign a new ID."""
104
107
  if new_track.curr_feat is not None:
105
108
  self.update_features(new_track.curr_feat)
106
109
  super().re_activate(new_track, frame_id, new_id)
107
110
 
108
- def update(self, new_track, frame_id):
111
+ def update(self, new_track: BOTrack, frame_id: int) -> None:
109
112
  """Update the track with new detection information and the current frame ID."""
110
113
  if new_track.curr_feat is not None:
111
114
  self.update_features(new_track.curr_feat)
112
115
  super().update(new_track, frame_id)
113
116
 
114
117
  @property
115
- def tlwh(self):
118
+ def tlwh(self) -> np.ndarray:
116
119
  """Return the current bounding box position in `(top left x, top left y, width, height)` format."""
117
120
  if self.mean is None:
118
121
  return self._tlwh.copy()
@@ -121,7 +124,7 @@ class BOTrack(STrack):
121
124
  return ret
122
125
 
123
126
  @staticmethod
124
- def multi_predict(stracks):
127
+ def multi_predict(stracks: list[BOTrack]) -> None:
125
128
  """Predict the mean and covariance for multiple object tracks using a shared Kalman filter."""
126
129
  if len(stracks) <= 0:
127
130
  return
@@ -136,12 +139,12 @@ class BOTrack(STrack):
136
139
  stracks[i].mean = mean
137
140
  stracks[i].covariance = cov
138
141
 
139
- def convert_coords(self, tlwh):
142
+ def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
140
143
  """Convert tlwh bounding box coordinates to xywh format."""
141
144
  return self.tlwh_to_xywh(tlwh)
142
145
 
143
146
  @staticmethod
144
- def tlwh_to_xywh(tlwh):
147
+ def tlwh_to_xywh(tlwh: np.ndarray) -> np.ndarray:
145
148
  """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
146
149
  ret = np.asarray(tlwh).copy()
147
150
  ret[:2] += ret[2:] / 2
@@ -149,8 +152,7 @@ class BOTrack(STrack):
149
152
 
150
153
 
151
154
  class BOTSORT(BYTETracker):
152
- """
153
- An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm.
155
+ """An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm.
154
156
 
155
157
  Attributes:
156
158
  proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
@@ -172,16 +174,15 @@ class BOTSORT(BYTETracker):
172
174
  >>> bot_sort.init_track(dets, scores, cls, img)
173
175
  >>> bot_sort.multi_predict(tracks)
174
176
 
175
- Note:
177
+ Notes:
176
178
  The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.
177
179
  """
178
180
 
179
- def __init__(self, args, frame_rate=30):
180
- """
181
- Initialize BOTSORT object with ReID module and GMC algorithm.
181
+ def __init__(self, args: Any, frame_rate: int = 30):
182
+ """Initialize BOTSORT object with ReID module and GMC algorithm.
182
183
 
183
184
  Args:
184
- args (object): Parsed command-line arguments containing tracking parameters.
185
+ args (Any): Parsed command-line arguments containing tracking parameters.
185
186
  frame_rate (int): Frame rate of the video being processed.
186
187
 
187
188
  Examples:
@@ -203,21 +204,23 @@ class BOTSORT(BYTETracker):
203
204
  else None
204
205
  )
205
206
 
206
- def get_kalmanfilter(self):
207
+ def get_kalmanfilter(self) -> KalmanFilterXYWH:
207
208
  """Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
208
209
  return KalmanFilterXYWH()
209
210
 
210
- def init_track(self, dets, scores, cls, img=None):
211
+ def init_track(self, results, img: np.ndarray | None = None) -> list[BOTrack]:
211
212
  """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
212
- if len(dets) == 0:
213
+ if len(results) == 0:
213
214
  return []
215
+ bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
216
+ bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
214
217
  if self.args.with_reid and self.encoder is not None:
215
- features_keep = self.encoder(img, dets)
216
- return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections
218
+ features_keep = self.encoder(img, bboxes)
219
+ return [BOTrack(xywh, s, c, f) for (xywh, s, c, f) in zip(bboxes, results.conf, results.cls, features_keep)]
217
220
  else:
218
- return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
221
+ return [BOTrack(xywh, s, c) for (xywh, s, c) in zip(bboxes, results.conf, results.cls)]
219
222
 
220
- def get_dists(self, tracks, detections):
223
+ def get_dists(self, tracks: list[BOTrack], detections: list[BOTrack]) -> np.ndarray:
221
224
  """Calculate distances between tracks and detections using IoU and optionally ReID embeddings."""
222
225
  dists = matching.iou_distance(tracks, detections)
223
226
  dists_mask = dists > (1 - self.proximity_thresh)
@@ -232,11 +235,11 @@ class BOTSORT(BYTETracker):
232
235
  dists = np.minimum(dists, emb_dists)
233
236
  return dists
234
237
 
235
- def multi_predict(self, tracks):
238
+ def multi_predict(self, tracks: list[BOTrack]) -> None:
236
239
  """Predict the mean and covariance of multiple object tracks using a shared Kalman filter."""
237
240
  BOTrack.multi_predict(tracks)
238
241
 
239
- def reset(self):
242
+ def reset(self) -> None:
240
243
  """Reset the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
241
244
  super().reset()
242
245
  self.gmc.reset_params()
@@ -245,16 +248,22 @@ class BOTSORT(BYTETracker):
245
248
  class ReID:
246
249
  """YOLO model as encoder for re-identification."""
247
250
 
248
- def __init__(self, model):
249
- """Initialize encoder for re-identification."""
251
+ def __init__(self, model: str):
252
+ """Initialize encoder for re-identification.
253
+
254
+ Args:
255
+ model (str): Path to the YOLO model for re-identification.
256
+ """
250
257
  from ultralytics import YOLO
251
258
 
252
259
  self.model = YOLO(model)
253
- self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False) # initialize
260
+ self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False, save=False) # init
254
261
 
255
- def __call__(self, img, dets):
262
+ def __call__(self, img: np.ndarray, dets: np.ndarray) -> list[np.ndarray]:
256
263
  """Extract embeddings for detected objects."""
257
- feats = self.model([save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))])
264
+ feats = self.model.predictor(
265
+ [save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))]
266
+ )
258
267
  if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]:
259
268
  feats = feats[0] # batched prediction with non-PyTorch backend
260
269
  return [f.cpu().numpy() for f in feats]