dgenerate-ultralytics-headless 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.145.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +11 -11
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +16 -13
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +52 -51
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +191 -161
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +4 -6
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +11 -10
  94. ultralytics/solutions/heatmap.py +2 -2
  95. ultralytics/solutions/instance_segmentation.py +7 -4
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +15 -11
  98. ultralytics/solutions/object_cropper.py +3 -2
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +189 -79
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +45 -29
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +71 -27
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/top_level.txt +0 -0
ultralytics/solutions/streamlit_inference.py

@@ -1,7 +1,7 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

  import io
- from typing import Any
+ from typing import Any, List

  import cv2

@@ -24,7 +24,7 @@ class Inference:
  model_path (str): Path to the loaded model.
  model (YOLO): The YOLO model instance.
  source (str): Selected video source (webcam or video file).
- enable_trk (str): Enable tracking option ("Yes" or "No").
+ enable_trk (bool): Enable tracking option.
  conf (float): Confidence threshold for detection.
  iou (float): IoU threshold for non-maximum suppression.
  org_frame (Any): Container for the original frame to be displayed.
@@ -33,14 +33,19 @@ class Inference:
  selected_ind (List[int]): List of selected class indices for detection.

  Methods:
- web_ui: Sets up the Streamlit web interface with custom HTML elements.
- sidebar: Configures the Streamlit sidebar for model and inference settings.
- source_upload: Handles video file uploads through the Streamlit interface.
- configure: Configures the model and loads selected classes for inference.
- inference: Performs real-time object detection inference.
+ web_ui: Set up the Streamlit web interface with custom HTML elements.
+ sidebar: Configure the Streamlit sidebar for model and inference settings.
+ source_upload: Handle video file uploads through the Streamlit interface.
+ configure: Configure the model and load selected classes for inference.
+ inference: Perform real-time object detection inference.

  Examples:
- >>> inf = Inference(model="path/to/model.pt") # Model is an optional argument
+ Create an Inference instance with a custom model
+ >>> inf = Inference(model="path/to/model.pt")
+ >>> inf.inference()
+
+ Create an Inference instance with default settings
+ >>> inf = Inference()
  >>> inf.inference()
  """

@@ -62,7 +67,7 @@ class Inference:
  self.org_frame = None # Container for the original frame display
  self.ann_frame = None # Container for the annotated frame display
  self.vid_file_name = None # Video file name or webcam index
- self.selected_ind = [] # List of selected class indices for detection
+ self.selected_ind: List[int] = [] # List of selected class indices for detection
  self.model = None # YOLO model instance

  self.temp_dict = {"model": None, **kwargs}
@@ -73,7 +78,7 @@ class Inference:
  LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")

  def web_ui(self):
- """Sets up the Streamlit web interface with custom HTML elements."""
+ """Set up the Streamlit web interface with custom HTML elements."""
  menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>""" # Hide main menu style

  # Main title of streamlit application
@@ -102,7 +107,7 @@ class Inference:
  "Video",
  ("webcam", "video"),
  ) # Add source selection dropdown
- self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) # Enable object tracking
+ self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes" # Enable object tracking
  self.conf = float(
  self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
  ) # Slider for confidence
@@ -166,7 +171,7 @@ class Inference:
  break

  # Process frame with model
- if self.enable_trk == "Yes":
+ if self.enable_trk:
  results = self.model.track(
  frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
  )
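
Note on the `enable_trk` change above: the sidebar's radio-button string is now converted to a boolean at assignment time, so the inference loop branches on a flag instead of comparing strings. A minimal sketch of the same pattern in isolation (the helper names are illustrative, not part of the package):

def read_tracking_flag(radio_value: str) -> bool:
    """Mirror the diff: turn the sidebar's "Yes"/"No" string into a bool once."""
    return radio_value == "Yes"

def process_frame(model, frame, enable_trk: bool, conf: float, iou: float, classes):
    """Branch on the boolean flag, as the updated inference() loop now does."""
    if enable_trk:
        return model.track(frame, conf=conf, iou=iou, classes=classes, persist=True)
    return model(frame, conf=conf, iou=iou, classes=classes)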
ultralytics/solutions/trackzone.py

@@ -23,9 +23,9 @@ class TrackZone(BaseSolution):
  clss (List[int]): Class indices of tracked objects.

  Methods:
- process: Processes each frame of the video, applying region-based tracking.
- extract_tracks: Extracts tracking information from the input frame.
- display_output: Displays the processed output.
+ process: Process each frame of the video, applying region-based tracking.
+ extract_tracks: Extract tracking information from the input frame.
+ display_output: Display the processed output.

  Examples:
  >>> tracker = TrackZone()
@@ -82,7 +82,7 @@ class TrackZone(BaseSolution):
  )

  plot_im = annotator.result()
- self.display_output(plot_im) # display output with base class function
+ self.display_output(plot_im) # Display output with base class function

  # Return a SolutionResults
  return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
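
For orientation, a minimal usage sketch of the `process` return value shown above; the `region` polygon, weights name, and image path are illustrative assumptions rather than values taken from this diff:

import cv2
from ultralytics import solutions

# Illustrative zone and weights; any polygon of (x, y) points and any detection weights work.
trackzone = solutions.TrackZone(model="yolo11n.pt", region=[(150, 150), (1130, 150), (1130, 570), (150, 570)])

frame = cv2.imread("frame.jpg")  # any BGR frame
results = trackzone.process(frame)  # returns SolutionResults, as in the diff above
print(results.total_tracks)  # number of objects tracked inside the zone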
ultralytics/trackers/basetrack.py

@@ -2,6 +2,7 @@
  """Module defines the base classes and structures for object tracking in YOLO."""

  from collections import OrderedDict
+ from typing import Any

  import numpy as np

@@ -66,15 +67,7 @@ class BaseTrack:
  _count = 0

  def __init__(self):
- """
- Initialize a new track with a unique ID and foundational tracking attributes.
-
- Examples:
- Initialize a new track
- >>> track = BaseTrack()
- >>> print(track.track_id)
- 0
- """
+ """Initialize a new track with a unique ID and foundational tracking attributes."""
  self.track_id = 0
  self.is_activated = False
  self.state = TrackState.New
@@ -88,37 +81,37 @@ class BaseTrack:
  self.location = (np.inf, np.inf)

  @property
- def end_frame(self):
- """Returns the ID of the most recent frame where the object was tracked."""
+ def end_frame(self) -> int:
+ """Return the ID of the most recent frame where the object was tracked."""
  return self.frame_id

  @staticmethod
- def next_id():
+ def next_id() -> int:
  """Increment and return the next unique global track ID for object tracking."""
  BaseTrack._count += 1
  return BaseTrack._count

- def activate(self, *args):
- """Activates the track with provided arguments, initializing necessary attributes for tracking."""
+ def activate(self, *args: Any) -> None:
+ """Activate the track with provided arguments, initializing necessary attributes for tracking."""
  raise NotImplementedError

- def predict(self):
- """Predicts the next state of the track based on the current state and tracking model."""
+ def predict(self) -> None:
+ """Predict the next state of the track based on the current state and tracking model."""
  raise NotImplementedError

- def update(self, *args, **kwargs):
- """Updates the track with new observations and data, modifying its state and attributes accordingly."""
+ def update(self, *args: Any, **kwargs: Any) -> None:
+ """Update the track with new observations and data, modifying its state and attributes accordingly."""
  raise NotImplementedError

- def mark_lost(self):
- """Marks the track as lost by updating its state to TrackState.Lost."""
+ def mark_lost(self) -> None:
+ """Mark the track as lost by updating its state to TrackState.Lost."""
  self.state = TrackState.Lost

- def mark_removed(self):
- """Marks the track as removed by setting its state to TrackState.Removed."""
+ def mark_removed(self) -> None:
+ """Mark the track as removed by setting its state to TrackState.Removed."""
  self.state = TrackState.Removed

  @staticmethod
- def reset_id():
+ def reset_id() -> None:
  """Reset the global track ID counter to its initial value."""
  BaseTrack._count = 0
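
Since `activate`, `predict`, and `update` raise NotImplementedError, BaseTrack is meant to be subclassed; below is a skeletal, purely illustrative subclass showing how the shared ID counter and state flags are reused (concrete trackers such as STrack add a Kalman filter on top):

from ultralytics.trackers.basetrack import BaseTrack, TrackState

class MinimalTrack(BaseTrack):
    """Illustrative subclass; not part of the package."""

    def activate(self, frame_id: int) -> None:
        self.track_id = self.next_id()  # global counter managed by BaseTrack
        self.frame_id = frame_id
        self.state = TrackState.Tracked
        self.is_activated = True

    def predict(self) -> None:
        pass  # no motion model in this sketch

    def update(self, frame_id: int) -> None:
        self.frame_id = frame_id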
ultralytics/trackers/bot_sort.py

@@ -1,6 +1,7 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

  from collections import deque
+ from typing import Any, List, Optional

  import numpy as np
  import torch
@@ -51,7 +52,9 @@ class BOTrack(STrack):

  shared_kalman = KalmanFilterXYWH()

- def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
+ def __init__(
+ self, tlwh: np.ndarray, score: float, cls: int, feat: Optional[np.ndarray] = None, feat_history: int = 50
+ ):
  """
  Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.

@@ -59,7 +62,7 @@ class BOTrack(STrack):
  tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).
  score (float): Confidence score of the detection.
  cls (int): Class ID of the detected object.
- feat (np.ndarray | None): Feature vector associated with the detection.
+ feat (np.ndarray, optional): Feature vector associated with the detection.
  feat_history (int): Maximum length of the feature history deque.

  Examples:
@@ -79,7 +82,7 @@ class BOTrack(STrack):
  self.features = deque([], maxlen=feat_history)
  self.alpha = 0.9

- def update_features(self, feat):
+ def update_features(self, feat: np.ndarray) -> None:
  """Update the feature vector and apply exponential moving average smoothing."""
  feat /= np.linalg.norm(feat)
  self.curr_feat = feat
@@ -90,7 +93,7 @@ class BOTrack(STrack):
  self.features.append(feat)
  self.smooth_feat /= np.linalg.norm(self.smooth_feat)

- def predict(self):
+ def predict(self) -> None:
  """Predict the object's future state using the Kalman filter to update its mean and covariance."""
  mean_state = self.mean.copy()
  if self.state != TrackState.Tracked:
@@ -99,20 +102,20 @@ class BOTrack(STrack):

  self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

- def re_activate(self, new_track, frame_id, new_id=False):
+ def re_activate(self, new_track: "BOTrack", frame_id: int, new_id: bool = False) -> None:
  """Reactivate a track with updated features and optionally assign a new ID."""
  if new_track.curr_feat is not None:
  self.update_features(new_track.curr_feat)
  super().re_activate(new_track, frame_id, new_id)

- def update(self, new_track, frame_id):
+ def update(self, new_track: "BOTrack", frame_id: int) -> None:
  """Update the track with new detection information and the current frame ID."""
  if new_track.curr_feat is not None:
  self.update_features(new_track.curr_feat)
  super().update(new_track, frame_id)

  @property
- def tlwh(self):
+ def tlwh(self) -> np.ndarray:
  """Return the current bounding box position in `(top left x, top left y, width, height)` format."""
  if self.mean is None:
  return self._tlwh.copy()
@@ -121,7 +124,7 @@ class BOTrack(STrack):
  return ret

  @staticmethod
- def multi_predict(stracks):
+ def multi_predict(stracks: List["BOTrack"]) -> None:
  """Predict the mean and covariance for multiple object tracks using a shared Kalman filter."""
  if len(stracks) <= 0:
  return
@@ -136,12 +139,12 @@ class BOTrack(STrack):
  stracks[i].mean = mean
  stracks[i].covariance = cov

- def convert_coords(self, tlwh):
+ def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
  """Convert tlwh bounding box coordinates to xywh format."""
  return self.tlwh_to_xywh(tlwh)

  @staticmethod
- def tlwh_to_xywh(tlwh):
+ def tlwh_to_xywh(tlwh: np.ndarray) -> np.ndarray:
  """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
  ret = np.asarray(tlwh).copy()
  ret[:2] += ret[2:] / 2
@@ -176,12 +179,12 @@ class BOTSORT(BYTETracker):
  The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.
  """

- def __init__(self, args, frame_rate=30):
+ def __init__(self, args: Any, frame_rate: int = 30):
  """
  Initialize BOTSORT object with ReID module and GMC algorithm.

  Args:
- args (object): Parsed command-line arguments containing tracking parameters.
+ args (Any): Parsed command-line arguments containing tracking parameters.
  frame_rate (int): Frame rate of the video being processed.

  Examples:
@@ -203,11 +206,13 @@ class BOTSORT(BYTETracker):
  else None
  )

- def get_kalmanfilter(self):
+ def get_kalmanfilter(self) -> KalmanFilterXYWH:
  """Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
  return KalmanFilterXYWH()

- def init_track(self, dets, scores, cls, img=None):
+ def init_track(
+ self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None
+ ) -> List[BOTrack]:
  """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
  if len(dets) == 0:
  return []
@@ -217,7 +222,7 @@ class BOTSORT(BYTETracker):
  else:
  return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections

- def get_dists(self, tracks, detections):
+ def get_dists(self, tracks: List[BOTrack], detections: List[BOTrack]) -> np.ndarray:
  """Calculate distances between tracks and detections using IoU and optionally ReID embeddings."""
  dists = matching.iou_distance(tracks, detections)
  dists_mask = dists > (1 - self.proximity_thresh)
@@ -232,11 +237,11 @@ class BOTSORT(BYTETracker):
  dists = np.minimum(dists, emb_dists)
  return dists

- def multi_predict(self, tracks):
+ def multi_predict(self, tracks: List[BOTrack]) -> None:
  """Predict the mean and covariance of multiple object tracks using a shared Kalman filter."""
  BOTrack.multi_predict(tracks)

- def reset(self):
+ def reset(self) -> None:
  """Reset the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
  super().reset()
  self.gmc.reset_params()
@@ -245,14 +250,19 @@ class BOTSORT(BYTETracker):
  class ReID:
  """YOLO model as encoder for re-identification."""

- def __init__(self, model):
- """Initialize encoder for re-identification."""
+ def __init__(self, model: str):
+ """
+ Initialize encoder for re-identification.
+
+ Args:
+ model (str): Path to the YOLO model for re-identification.
+ """
  from ultralytics import YOLO

  self.model = YOLO(model)
  self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False) # initialize

- def __call__(self, img, dets):
+ def __call__(self, img: np.ndarray, dets: np.ndarray) -> List[np.ndarray]:
  """Extract embeddings for detected objects."""
  feats = self.model([save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))])
  if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]:
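
The `tlwh_to_xywh` helper typed above shifts the top-left corner to the box centre (`ret[:2] += ret[2:] / 2`); a quick standalone check of that arithmetic with plain NumPy:

import numpy as np

tlwh = np.array([100.0, 50.0, 40.0, 20.0])  # top-left x, top-left y, width, height

xywh = tlwh.copy()
xywh[:2] += xywh[2:] / 2  # centre x = 100 + 40 / 2 = 120, centre y = 50 + 20 / 2 = 60
print(xywh)  # [120.  60.  40.  20.]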
ultralytics/trackers/byte_tracker.py

@@ -1,5 +1,7 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from typing import Any, List, Optional, Tuple
+
  import numpy as np

  from ..utils import LOGGER
@@ -29,16 +31,17 @@ class STrack(BaseTrack):
  idx (int): Index or identifier for the object.
  frame_id (int): Current frame ID.
  start_frame (int): Frame where the object was first detected.
+ angle (float | None): Optional angle information for oriented bounding boxes.

  Methods:
- predict(): Predict the next state of the object using Kalman filter.
- multi_predict(stracks): Predict the next states for multiple tracks.
- multi_gmc(stracks, H): Update multiple track states using a homography matrix.
- activate(kalman_filter, frame_id): Activate a new tracklet.
- re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
- update(new_track, frame_id): Update the state of a matched track.
- convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
- tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
+ predict: Predict the next state of the object using Kalman filter.
+ multi_predict: Predict the next states for multiple tracks.
+ multi_gmc: Update multiple track states using a homography matrix.
+ activate: Activate a new tracklet.
+ re_activate: Reactivate a previously lost tracklet.
+ update: Update the state of a matched track.
+ convert_coords: Convert bounding box to x-y-aspect-height format.
+ tlwh_to_xyah: Convert tlwh bounding box to xyah format.

  Examples:
  Initialize and activate a new track
@@ -48,7 +51,7 @@ class STrack(BaseTrack):

  shared_kalman = KalmanFilterXYAH()

- def __init__(self, xywh, score, cls):
+ def __init__(self, xywh: List[float], score: float, cls: Any):
  """
  Initialize a new STrack instance.

@@ -79,14 +82,14 @@ class STrack(BaseTrack):
  self.angle = xywh[4] if len(xywh) == 6 else None

  def predict(self):
- """Predicts the next state (mean and covariance) of the object using the Kalman filter."""
+ """Predict the next state (mean and covariance) of the object using the Kalman filter."""
  mean_state = self.mean.copy()
  if self.state != TrackState.Tracked:
  mean_state[7] = 0
  self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

  @staticmethod
- def multi_predict(stracks):
+ def multi_predict(stracks: List["STrack"]):
  """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
  if len(stracks) <= 0:
  return
@@ -101,7 +104,7 @@ class STrack(BaseTrack):
  stracks[i].covariance = cov

  @staticmethod
- def multi_gmc(stracks, H=np.eye(2, 3)):
+ def multi_gmc(stracks: List["STrack"], H: np.ndarray = np.eye(2, 3)):
  """Update state tracks positions and covariances using a homography matrix for multiple tracks."""
  if len(stracks) > 0:
  multi_mean = np.asarray([st.mean.copy() for st in stracks])
@@ -119,7 +122,7 @@ class STrack(BaseTrack):
  stracks[i].mean = mean
  stracks[i].covariance = cov

- def activate(self, kalman_filter, frame_id):
+ def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):
  """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
  self.kalman_filter = kalman_filter
  self.track_id = self.next_id()
@@ -132,8 +135,8 @@ class STrack(BaseTrack):
  self.frame_id = frame_id
  self.start_frame = frame_id

- def re_activate(self, new_track, frame_id, new_id=False):
- """Reactivates a previously lost track using new detection data and updates its state and attributes."""
+ def re_activate(self, new_track: "STrack", frame_id: int, new_id: bool = False):
+ """Reactivate a previously lost track using new detection data and update its state and attributes."""
  self.mean, self.covariance = self.kalman_filter.update(
  self.mean, self.covariance, self.convert_coords(new_track.tlwh)
  )
@@ -148,7 +151,7 @@ class STrack(BaseTrack):
  self.angle = new_track.angle
  self.idx = new_track.idx

- def update(self, new_track, frame_id):
+ def update(self, new_track: "STrack", frame_id: int):
  """
  Update the state of a matched track.

@@ -177,13 +180,13 @@ class STrack(BaseTrack):
  self.angle = new_track.angle
  self.idx = new_track.idx

- def convert_coords(self, tlwh):
+ def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
  """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
  return self.tlwh_to_xyah(tlwh)

  @property
- def tlwh(self):
- """Returns the bounding box in top-left-width-height format from the current state estimate."""
+ def tlwh(self) -> np.ndarray:
+ """Get the bounding box in top-left-width-height format from the current state estimate."""
  if self.mean is None:
  return self._tlwh.copy()
  ret = self.mean[:4].copy()
@@ -192,14 +195,14 @@ class STrack(BaseTrack):
  return ret

  @property
- def xyxy(self):
- """Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
+ def xyxy(self) -> np.ndarray:
+ """Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
  ret = self.tlwh.copy()
  ret[2:] += ret[:2]
  return ret

  @staticmethod
- def tlwh_to_xyah(tlwh):
+ def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:
  """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
  ret = np.asarray(tlwh).copy()
  ret[:2] += ret[2:] / 2
@@ -207,28 +210,28 @@ class STrack(BaseTrack):
  return ret

  @property
- def xywh(self):
- """Returns the current position of the bounding box in (center x, center y, width, height) format."""
+ def xywh(self) -> np.ndarray:
+ """Get the current position of the bounding box in (center x, center y, width, height) format."""
  ret = np.asarray(self.tlwh).copy()
  ret[:2] += ret[2:] / 2
  return ret

  @property
- def xywha(self):
- """Returns position in (center x, center y, width, height, angle) format, warning if angle is missing."""
+ def xywha(self) -> np.ndarray:
+ """Get position in (center x, center y, width, height, angle) format, warning if angle is missing."""
  if self.angle is None:
  LOGGER.warning("`angle` attr not found, returning `xywh` instead.")
  return self.xywh
  return np.concatenate([self.xywh, self.angle[None]])

  @property
- def result(self):
- """Returns the current tracking results in the appropriate bounding box format."""
+ def result(self) -> List[float]:
+ """Get the current tracking results in the appropriate bounding box format."""
  coords = self.xyxy if self.angle is None else self.xywha
  return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]

- def __repr__(self):
- """Returns a string representation of the STrack object including start frame, end frame, and track ID."""
+ def __repr__(self) -> str:
+ """Return a string representation of the STrack object including start frame, end frame, and track ID."""
  return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"


@@ -250,15 +253,16 @@ class BYTETracker:
  kalman_filter (KalmanFilterXYAH): Kalman Filter object.

  Methods:
- update(results, img=None): Updates object tracker with new detections.
- get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes.
- init_track(dets, scores, cls, img=None): Initialize object tracking with detections.
- get_dists(tracks, detections): Calculates the distance between tracks and detections.
- multi_predict(tracks): Predicts the location of tracks.
- reset_id(): Resets the ID counter of STrack.
- joint_stracks(tlista, tlistb): Combines two lists of stracks.
- sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
- remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
+ update: Update object tracker with new detections.
+ get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.
+ init_track: Initialize object tracking with detections.
+ get_dists: Calculate the distance between tracks and detections.
+ multi_predict: Predict the location of tracks.
+ reset_id: Reset the ID counter of STrack.
+ reset: Reset the tracker by clearing all tracks.
+ joint_stracks: Combine two lists of stracks.
+ sub_stracks: Filter out the stracks present in the second list from the first list.
+ remove_duplicate_stracks: Remove duplicate stracks based on IoU.

  Examples:
  Initialize BYTETracker and update with detection results
@@ -267,7 +271,7 @@ class BYTETracker:
  >>> tracked_objects = tracker.update(results)
  """

- def __init__(self, args, frame_rate=30):
+ def __init__(self, args, frame_rate: int = 30):
  """
  Initialize a BYTETracker instance for object tracking.

@@ -280,9 +284,9 @@ class BYTETracker:
  >>> args = Namespace(track_buffer=30)
  >>> tracker = BYTETracker(args, frame_rate=30)
  """
- self.tracked_stracks = [] # type: list[STrack]
- self.lost_stracks = [] # type: list[STrack]
- self.removed_stracks = [] # type: list[STrack]
+ self.tracked_stracks = [] # type: List[STrack]
+ self.lost_stracks = [] # type: List[STrack]
+ self.removed_stracks = [] # type: List[STrack]

  self.frame_id = 0
  self.args = args
@@ -290,8 +294,8 @@ class BYTETracker:
  self.kalman_filter = self.get_kalmanfilter()
  self.reset_id()

- def update(self, results, img=None, feats=None):
- """Updates the tracker with new detections and returns the current list of tracked objects."""
+ def update(self, results, img: Optional[np.ndarray] = None, feats: Optional[np.ndarray] = None) -> np.ndarray:
+ """Update the tracker with new detections and return the current list of tracked objects."""
  self.frame_id += 1
  activated_stracks = []
  refind_stracks = []
@@ -319,7 +323,7 @@ class BYTETracker:
  detections = self.init_track(dets, scores_keep, cls_keep, img if feats is None else feats)
  # Add newly detected tracklets to tracked_stracks
  unconfirmed = []
- tracked_stracks = [] # type: list[STrack]
+ tracked_stracks = [] # type: List[STrack]
  for track in self.tracked_stracks:
  if not track.is_activated:
  unconfirmed.append(track)
@@ -408,42 +412,44 @@ class BYTETracker:

  return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)

- def get_kalmanfilter(self):
- """Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
+ def get_kalmanfilter(self) -> KalmanFilterXYAH:
+ """Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
  return KalmanFilterXYAH()

- def init_track(self, dets, scores, cls, img=None):
- """Initializes object tracking with given detections, scores, and class labels using the STrack algorithm."""
+ def init_track(
+ self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None
+ ) -> List[STrack]:
+ """Initialize object tracking with given detections, scores, and class labels using the STrack algorithm."""
  return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections

- def get_dists(self, tracks, detections):
- """Calculates the distance between tracks and detections using IoU and optionally fuses scores."""
+ def get_dists(self, tracks: List[STrack], detections: List[STrack]) -> np.ndarray:
+ """Calculate the distance between tracks and detections using IoU and optionally fuse scores."""
  dists = matching.iou_distance(tracks, detections)
  if self.args.fuse_score:
  dists = matching.fuse_score(dists, detections)
  return dists

- def multi_predict(self, tracks):
+ def multi_predict(self, tracks: List[STrack]):
  """Predict the next states for multiple tracks using Kalman filter."""
  STrack.multi_predict(tracks)

  @staticmethod
  def reset_id():
- """Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
+ """Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
  STrack.reset_id()

  def reset(self):
- """Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
- self.tracked_stracks = [] # type: list[STrack]
- self.lost_stracks = [] # type: list[STrack]
- self.removed_stracks = [] # type: list[STrack]
+ """Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
+ self.tracked_stracks = [] # type: List[STrack]
+ self.lost_stracks = [] # type: List[STrack]
+ self.removed_stracks = [] # type: List[STrack]
  self.frame_id = 0
  self.kalman_filter = self.get_kalmanfilter()
  self.reset_id()

  @staticmethod
- def joint_stracks(tlista, tlistb):
- """Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
+ def joint_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:
+ """Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
  exists = {}
  res = []
  for t in tlista:
@@ -457,14 +463,14 @@ class BYTETracker:
  return res

  @staticmethod
- def sub_stracks(tlista, tlistb):
- """Filters out the stracks present in the second list from the first list."""
+ def sub_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:
+ """Filter out the stracks present in the second list from the first list."""
  track_ids_b = {t.track_id for t in tlistb}
  return [t for t in tlista if t.track_id not in track_ids_b]

  @staticmethod
- def remove_duplicate_stracks(stracksa, stracksb):
- """Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
+ def remove_duplicate_stracks(stracksa: List[STrack], stracksb: List[STrack]) -> Tuple[List[STrack], List[STrack]]:
+ """Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
  pdist = matching.iou_distance(stracksa, stracksb)
  pairs = np.where(pdist < 0.15)
  dupa, dupb = [], []
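
The BYTETracker docstring examples above construct the tracker from a small `args` namespace; a slightly expanded sketch follows. Only `track_buffer` and `fuse_score` appear in this diff; the remaining fields are assumptions about the argument object consumed by `update()` and `init_track()`, and the commented `update()` call expects a detection-results object from an Ultralytics predictor:

from argparse import Namespace

from ultralytics.trackers.byte_tracker import BYTETracker

args = Namespace(
    track_buffer=30,  # from the docstring example above
    fuse_score=True,  # consumed by get_dists(), as shown in the diff
    # Assumed threshold fields; names are illustrative, not taken from this diff.
    track_high_thresh=0.25,
    track_low_thresh=0.1,
    new_track_thresh=0.25,
    match_thresh=0.8,
)
tracker = BYTETracker(args, frame_rate=30)
# tracked = tracker.update(results)  # `results` comes from an Ultralytics predictor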