ultralytics-8.3.142-py3-none-any.whl → ultralytics-8.3.144-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +12 -12
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +39 -39
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +187 -157
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +1 -1
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +1 -1
  93. ultralytics/solutions/instance_segmentation.py +6 -3
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +16 -8
  96. ultralytics/solutions/object_cropper.py +12 -5
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +215 -85
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +42 -28
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
  143. ultralytics-8.3.144.dist-info/RECORD +272 -0
  144. ultralytics-8.3.142.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/trackers/byte_tracker.py +70 -64

@@ -1,5 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from typing import Any, List, Optional, Tuple
+
 import numpy as np
 
 from ..utils import LOGGER
@@ -29,16 +31,17 @@ class STrack(BaseTrack):
         idx (int): Index or identifier for the object.
         frame_id (int): Current frame ID.
         start_frame (int): Frame where the object was first detected.
+        angle (float | None): Optional angle information for oriented bounding boxes.
 
     Methods:
-        predict(): Predict the next state of the object using Kalman filter.
-        multi_predict(stracks): Predict the next states for multiple tracks.
-        multi_gmc(stracks, H): Update multiple track states using a homography matrix.
-        activate(kalman_filter, frame_id): Activate a new tracklet.
-        re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
-        update(new_track, frame_id): Update the state of a matched track.
-        convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
-        tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
+        predict: Predict the next state of the object using Kalman filter.
+        multi_predict: Predict the next states for multiple tracks.
+        multi_gmc: Update multiple track states using a homography matrix.
+        activate: Activate a new tracklet.
+        re_activate: Reactivate a previously lost tracklet.
+        update: Update the state of a matched track.
+        convert_coords: Convert bounding box to x-y-aspect-height format.
+        tlwh_to_xyah: Convert tlwh bounding box to xyah format.
 
     Examples:
         Initialize and activate a new track
@@ -48,7 +51,7 @@ class STrack(BaseTrack):
 
     shared_kalman = KalmanFilterXYAH()
 
-    def __init__(self, xywh, score, cls):
+    def __init__(self, xywh: List[float], score: float, cls: Any):
        """
        Initialize a new STrack instance.
 
@@ -79,14 +82,14 @@ class STrack(BaseTrack):
         self.angle = xywh[4] if len(xywh) == 6 else None
 
     def predict(self):
-        """Predicts the next state (mean and covariance) of the object using the Kalman filter."""
+        """Predict the next state (mean and covariance) of the object using the Kalman filter."""
         mean_state = self.mean.copy()
         if self.state != TrackState.Tracked:
             mean_state[7] = 0
         self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
 
     @staticmethod
-    def multi_predict(stracks):
+    def multi_predict(stracks: List["STrack"]):
         """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
         if len(stracks) <= 0:
             return
@@ -101,7 +104,7 @@ class STrack(BaseTrack):
             stracks[i].covariance = cov
 
     @staticmethod
-    def multi_gmc(stracks, H=np.eye(2, 3)):
+    def multi_gmc(stracks: List["STrack"], H: np.ndarray = np.eye(2, 3)):
         """Update state tracks positions and covariances using a homography matrix for multiple tracks."""
         if len(stracks) > 0:
             multi_mean = np.asarray([st.mean.copy() for st in stracks])
@@ -119,7 +122,7 @@ class STrack(BaseTrack):
             stracks[i].mean = mean
             stracks[i].covariance = cov
 
-    def activate(self, kalman_filter, frame_id):
+    def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):
         """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
         self.kalman_filter = kalman_filter
         self.track_id = self.next_id()
@@ -132,8 +135,8 @@ class STrack(BaseTrack):
         self.frame_id = frame_id
         self.start_frame = frame_id
 
-    def re_activate(self, new_track, frame_id, new_id=False):
-        """Reactivates a previously lost track using new detection data and updates its state and attributes."""
+    def re_activate(self, new_track: "STrack", frame_id: int, new_id: bool = False):
+        """Reactivate a previously lost track using new detection data and update its state and attributes."""
         self.mean, self.covariance = self.kalman_filter.update(
             self.mean, self.covariance, self.convert_coords(new_track.tlwh)
         )
@@ -148,7 +151,7 @@ class STrack(BaseTrack):
         self.angle = new_track.angle
         self.idx = new_track.idx
 
-    def update(self, new_track, frame_id):
+    def update(self, new_track: "STrack", frame_id: int):
        """
        Update the state of a matched track.
 
@@ -177,13 +180,13 @@ class STrack(BaseTrack):
         self.angle = new_track.angle
         self.idx = new_track.idx
 
-    def convert_coords(self, tlwh):
+    def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
         """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
         return self.tlwh_to_xyah(tlwh)
 
     @property
-    def tlwh(self):
-        """Returns the bounding box in top-left-width-height format from the current state estimate."""
+    def tlwh(self) -> np.ndarray:
+        """Get the bounding box in top-left-width-height format from the current state estimate."""
         if self.mean is None:
             return self._tlwh.copy()
         ret = self.mean[:4].copy()
@@ -192,14 +195,14 @@ class STrack(BaseTrack):
         return ret
 
     @property
-    def xyxy(self):
-        """Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
+    def xyxy(self) -> np.ndarray:
+        """Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
         ret = self.tlwh.copy()
         ret[2:] += ret[:2]
         return ret
 
     @staticmethod
-    def tlwh_to_xyah(tlwh):
+    def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:
         """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
         ret = np.asarray(tlwh).copy()
         ret[:2] += ret[2:] / 2
@@ -207,28 +210,28 @@ class STrack(BaseTrack):
         return ret
 
     @property
-    def xywh(self):
-        """Returns the current position of the bounding box in (center x, center y, width, height) format."""
+    def xywh(self) -> np.ndarray:
+        """Get the current position of the bounding box in (center x, center y, width, height) format."""
         ret = np.asarray(self.tlwh).copy()
         ret[:2] += ret[2:] / 2
         return ret
 
     @property
-    def xywha(self):
-        """Returns position in (center x, center y, width, height, angle) format, warning if angle is missing."""
+    def xywha(self) -> np.ndarray:
+        """Get position in (center x, center y, width, height, angle) format, warning if angle is missing."""
         if self.angle is None:
             LOGGER.warning("`angle` attr not found, returning `xywh` instead.")
             return self.xywh
         return np.concatenate([self.xywh, self.angle[None]])
 
     @property
-    def result(self):
-        """Returns the current tracking results in the appropriate bounding box format."""
+    def result(self) -> List[float]:
+        """Get the current tracking results in the appropriate bounding box format."""
         coords = self.xyxy if self.angle is None else self.xywha
         return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
 
-    def __repr__(self):
-        """Returns a string representation of the STrack object including start frame, end frame, and track ID."""
+    def __repr__(self) -> str:
+        """Return a string representation of the STrack object including start frame, end frame, and track ID."""
         return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
 
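The xyah format that the shared Kalman filter operates on stores center position, aspect ratio (width/height), and height. A minimal standalone sketch of the conversion the newly typed `tlwh_to_xyah` helper performs (plain NumPy; the aspect-ratio line falls just outside the hunk context above, so treat it as an assumption inferred from the format name):

```python
import numpy as np


def tlwh_to_xyah(tlwh):
    """Convert (top-left x, top-left y, width, height) to (center x, center y, w/h aspect, height)."""
    ret = np.asarray(tlwh, dtype=float).copy()
    ret[:2] += ret[2:] / 2  # shift the top-left corner to the box center
    ret[2] /= ret[3]  # replace width with the aspect ratio
    return ret


print(tlwh_to_xyah([10, 20, 40, 80]))  # [30. 60. 0.5 80.]
```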
@@ -250,15 +253,16 @@ class BYTETracker:
         kalman_filter (KalmanFilterXYAH): Kalman Filter object.
 
     Methods:
-        update(results, img=None): Updates object tracker with new detections.
-        get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes.
-        init_track(dets, scores, cls, img=None): Initialize object tracking with detections.
-        get_dists(tracks, detections): Calculates the distance between tracks and detections.
-        multi_predict(tracks): Predicts the location of tracks.
-        reset_id(): Resets the ID counter of STrack.
-        joint_stracks(tlista, tlistb): Combines two lists of stracks.
-        sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
-        remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
+        update: Update object tracker with new detections.
+        get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.
+        init_track: Initialize object tracking with detections.
+        get_dists: Calculate the distance between tracks and detections.
+        multi_predict: Predict the location of tracks.
+        reset_id: Reset the ID counter of STrack.
+        reset: Reset the tracker by clearing all tracks.
+        joint_stracks: Combine two lists of stracks.
+        sub_stracks: Filter out the stracks present in the second list from the first list.
+        remove_duplicate_stracks: Remove duplicate stracks based on IoU.
 
     Examples:
         Initialize BYTETracker and update with detection results
@@ -267,7 +271,7 @@ class BYTETracker:
         >>> tracked_objects = tracker.update(results)
     """
 
-    def __init__(self, args, frame_rate=30):
+    def __init__(self, args, frame_rate: int = 30):
         """
         Initialize a BYTETracker instance for object tracking.
 
@@ -280,9 +284,9 @@ class BYTETracker:
             >>> args = Namespace(track_buffer=30)
             >>> tracker = BYTETracker(args, frame_rate=30)
         """
-        self.tracked_stracks = []  # type: list[STrack]
-        self.lost_stracks = []  # type: list[STrack]
-        self.removed_stracks = []  # type: list[STrack]
+        self.tracked_stracks = []  # type: List[STrack]
+        self.lost_stracks = []  # type: List[STrack]
+        self.removed_stracks = []  # type: List[STrack]
 
         self.frame_id = 0
         self.args = args
@@ -290,8 +294,8 @@ class BYTETracker:
         self.kalman_filter = self.get_kalmanfilter()
         self.reset_id()
 
-    def update(self, results, img=None, feats=None):
-        """Updates the tracker with new detections and returns the current list of tracked objects."""
+    def update(self, results, img: Optional[np.ndarray] = None, feats: Optional[np.ndarray] = None) -> np.ndarray:
+        """Update the tracker with new detections and return the current list of tracked objects."""
         self.frame_id += 1
         activated_stracks = []
         refind_stracks = []
@@ -319,7 +323,7 @@ class BYTETracker:
         detections = self.init_track(dets, scores_keep, cls_keep, img if feats is None else feats)
         # Add newly detected tracklets to tracked_stracks
         unconfirmed = []
-        tracked_stracks = []  # type: list[STrack]
+        tracked_stracks = []  # type: List[STrack]
         for track in self.tracked_stracks:
             if not track.is_activated:
                 unconfirmed.append(track)
@@ -408,42 +412,44 @@ class BYTETracker:
 
         return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
 
-    def get_kalmanfilter(self):
-        """Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
+    def get_kalmanfilter(self) -> KalmanFilterXYAH:
+        """Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
         return KalmanFilterXYAH()
 
-    def init_track(self, dets, scores, cls, img=None):
-        """Initializes object tracking with given detections, scores, and class labels using the STrack algorithm."""
+    def init_track(
+        self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None
+    ) -> List[STrack]:
+        """Initialize object tracking with given detections, scores, and class labels using the STrack algorithm."""
         return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else []  # detections
 
-    def get_dists(self, tracks, detections):
-        """Calculates the distance between tracks and detections using IoU and optionally fuses scores."""
+    def get_dists(self, tracks: List[STrack], detections: List[STrack]) -> np.ndarray:
+        """Calculate the distance between tracks and detections using IoU and optionally fuse scores."""
         dists = matching.iou_distance(tracks, detections)
         if self.args.fuse_score:
             dists = matching.fuse_score(dists, detections)
         return dists
 
-    def multi_predict(self, tracks):
+    def multi_predict(self, tracks: List[STrack]):
         """Predict the next states for multiple tracks using Kalman filter."""
         STrack.multi_predict(tracks)
 
     @staticmethod
     def reset_id():
-        """Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
+        """Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
         STrack.reset_id()
 
     def reset(self):
-        """Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
-        self.tracked_stracks = []  # type: list[STrack]
-        self.lost_stracks = []  # type: list[STrack]
-        self.removed_stracks = []  # type: list[STrack]
+        """Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
+        self.tracked_stracks = []  # type: List[STrack]
+        self.lost_stracks = []  # type: List[STrack]
+        self.removed_stracks = []  # type: List[STrack]
         self.frame_id = 0
         self.kalman_filter = self.get_kalmanfilter()
         self.reset_id()
 
     @staticmethod
-    def joint_stracks(tlista, tlistb):
-        """Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
+    def joint_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:
+        """Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
         exists = {}
         res = []
         for t in tlista:
@@ -457,14 +463,14 @@ class BYTETracker:
         return res
 
     @staticmethod
-    def sub_stracks(tlista, tlistb):
-        """Filters out the stracks present in the second list from the first list."""
+    def sub_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:
+        """Filter out the stracks present in the second list from the first list."""
         track_ids_b = {t.track_id for t in tlistb}
         return [t for t in tlista if t.track_id not in track_ids_b]
 
     @staticmethod
-    def remove_duplicate_stracks(stracksa, stracksb):
-        """Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
+    def remove_duplicate_stracks(stracksa: List[STrack], stracksb: List[STrack]) -> Tuple[List[STrack], List[STrack]]:
+        """Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
         pdist = matching.iou_distance(stracksa, stracksb)
         pairs = np.where(pdist < 0.15)
         dupa, dupb = [], []
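With the new annotations, `update()` documents that it accepts optional image/feature arrays and returns an `np.ndarray`; for axis-aligned boxes each row is `x1, y1, x2, y2, track_id, score, cls, idx` (see `STrack.result` above). A hedged sketch of driving `BYTETracker` directly, assuming the `bytetrack.yaml` default hyperparameters and a hypothetical `SimpleNamespace` stand-in for the `Results` boxes object; the `.conf`, `.xywh`, and `.cls` attribute reads inside `update()` are assumptions not shown in this diff:

```python
from types import SimpleNamespace

import numpy as np

from ultralytics.trackers.byte_tracker import BYTETracker

# Hyperparameter values assumed to mirror the bytetrack.yaml defaults
args = SimpleNamespace(
    track_high_thresh=0.25,
    track_low_thresh=0.1,
    new_track_thresh=0.25,
    track_buffer=30,
    match_thresh=0.8,
    fuse_score=True,
)
tracker = BYTETracker(args, frame_rate=30)

# Hypothetical stand-in for Results.boxes with two detections
detections = SimpleNamespace(
    conf=np.array([0.9, 0.6]),
    xywh=np.array([[50.0, 40.0, 20.0, 30.0], [120.0, 80.0, 25.0, 25.0]]),
    cls=np.array([0.0, 2.0]),
)
tracks = tracker.update(detections)  # rows: x1, y1, x2, y2, track_id, score, cls, idx
```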
ultralytics/trackers/track.py +4 -8

@@ -20,15 +20,11 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
     Initialize trackers for object tracking during prediction.
 
     Args:
-        predictor (object): The predictor object to initialize trackers for.
-        persist (bool): Whether to persist the trackers if they already exist.
-
-    Raises:
-        AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
-        ValueError: If the task is 'classify' as classification doesn't support tracking.
+        predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.
+        persist (bool, optional): Whether to persist the trackers if they already exist.
 
     Examples:
-        Initialize trackers for a predictor object:
+        Initialize trackers for a predictor object
         >>> predictor = SomePredictorClass()
         >>> on_predict_start(predictor, persist=True)
     """
@@ -79,7 +75,7 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
 
     Args:
         predictor (object): The predictor object containing the predictions.
-        persist (bool): Whether to persist the trackers if they already exist.
+        persist (bool, optional): Whether to persist the trackers if they already exist.
 
     Examples:
         Postprocess predictions and update with tracking
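In normal use these callbacks are registered by `model.track()`, which builds one tracker per stream in `on_predict_start` and reuses it across calls when `persist=True`. A brief hedged sketch (model and video file names illustrative):

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")  # any detection model

# persist=True keeps tracker state across successive track() calls on the same source
for result in model.track("video.mp4", tracker="bytetrack.yaml", persist=True, stream=True):
    print(result.boxes.id)  # per-frame track IDs (None when nothing is tracked yet)
```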
ultralytics/trackers/utils/gmc.py +31 -58

@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import copy
+from typing import List, Optional
 
 import cv2
 import numpy as np
@@ -19,7 +20,7 @@ class GMC:
         method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
         downscale (int): Factor by which to downscale the frames for processing.
         prevFrame (np.ndarray): Previous frame for tracking.
-        prevKeyPoints (list): Keypoints from the previous frame.
+        prevKeyPoints (List): Keypoints from the previous frame.
         prevDescriptors (np.ndarray): Descriptors from the previous frame.
         initializedFirstFrame (bool): Flag indicating if the first frame has been processed.
 
@@ -88,13 +89,13 @@ class GMC:
         self.prevDescriptors = None
         self.initializedFirstFrame = False
 
-    def apply(self, raw_frame: np.ndarray, detections: list = None) -> np.ndarray:
+    def apply(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:
         """
         Apply object detection on a raw frame using the specified method.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
-            detections (List | None): List of detections to be used in the processing.
+            detections (List, optional): List of detections to be used in the processing.
 
         Returns:
             (np.ndarray): Transformation matrix with shape (2, 3).
@@ -136,23 +137,18 @@ class GMC:
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
         H = np.eye(2, 3, dtype=np.float32)
 
-        # Downscale image
+        # Downscale image for computational efficiency
         if self.downscale > 1.0:
             frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
             frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
 
-        # Handle first frame
+        # Handle first frame initialization
         if not self.initializedFirstFrame:
-            # Initialize data
             self.prevFrame = frame.copy()
-
-            # Initialization done
             self.initializedFirstFrame = True
-
             return H
 
-        # Run the ECC algorithm. The results are stored in warp_matrix.
-        # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
+        # Run the ECC algorithm to find transformation matrix
         try:
             (_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
         except Exception as e:
@@ -160,13 +156,13 @@ class GMC:
 
         return H
 
-    def apply_features(self, raw_frame: np.ndarray, detections: list = None) -> np.ndarray:
+    def apply_features(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:
         """
         Apply feature-based methods like ORB or SIFT to a raw frame.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
-            detections (List | None): List of detections to be used in the processing.
+            detections (List, optional): List of detections to be used in the processing.
 
         Returns:
             (np.ndarray): Transformation matrix with shape (2, 3).
@@ -182,55 +178,50 @@ class GMC:
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
         H = np.eye(2, 3)
 
-        # Downscale image
+        # Downscale image for computational efficiency
         if self.downscale > 1.0:
             frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
             width = width // self.downscale
             height = height // self.downscale
 
-        # Find the keypoints
+        # Create mask for keypoint detection, excluding border regions
         mask = np.zeros_like(frame)
         mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255
+
+        # Exclude detection regions from mask to avoid tracking detected objects
         if detections is not None:
             for det in detections:
                 tlbr = (det[:4] / self.downscale).astype(np.int_)
                 mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0
 
+        # Find keypoints and compute descriptors
         keypoints = self.detector.detect(frame, mask)
-
-        # Compute the descriptors
         keypoints, descriptors = self.extractor.compute(frame, keypoints)
 
-        # Handle first frame
+        # Handle first frame initialization
         if not self.initializedFirstFrame:
-            # Initialize data
             self.prevFrame = frame.copy()
             self.prevKeyPoints = copy.copy(keypoints)
             self.prevDescriptors = copy.copy(descriptors)
-
-            # Initialization done
             self.initializedFirstFrame = True
-
             return H
 
-        # Match descriptors
+        # Match descriptors between previous and current frame
         knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
 
-        # Filter matches based on smallest spatial distance
+        # Filter matches based on spatial distance constraints
         matches = []
         spatialDistances = []
-
         maxSpatialDistance = 0.25 * np.array([width, height])
 
         # Handle empty matches case
         if len(knnMatches) == 0:
-            # Store to next iteration
             self.prevFrame = frame.copy()
             self.prevKeyPoints = copy.copy(keypoints)
             self.prevDescriptors = copy.copy(descriptors)
-
             return H
 
+        # Apply Lowe's ratio test and spatial distance filtering
         for m, n in knnMatches:
             if m.distance < 0.9 * n.distance:
                 prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
@@ -247,11 +238,12 @@ class GMC:
                 spatialDistances.append(spatialDistance)
                 matches.append(m)
 
+        # Filter outliers using statistical analysis
         meanSpatialDistances = np.mean(spatialDistances, 0)
         stdSpatialDistances = np.std(spatialDistances, 0)
-
         inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances
 
+        # Extract good matches and corresponding points
         goodMatches = []
         prevPoints = []
         currPoints = []
@@ -264,39 +256,18 @@ class GMC:
         prevPoints = np.array(prevPoints)
         currPoints = np.array(currPoints)
 
-        # Draw the keypoint matches on the output image
-        # if False:
-        #     import matplotlib.pyplot as plt
-        #     matches_img = np.hstack((self.prevFrame, frame))
-        #     matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
-        #     W = self.prevFrame.shape[1]
-        #     for m in goodMatches:
-        #         prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
-        #         curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
-        #         curr_pt[0] += W
-        #         color = np.random.randint(0, 255, 3)
-        #         color = (int(color[0]), int(color[1]), int(color[2]))
-        #
-        #         matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA)
-        #         matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1)
-        #         matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1)
-        #
-        #     plt.figure()
-        #     plt.imshow(matches_img)
-        #     plt.show()
-
-        # Find rigid matrix
+        # Estimate transformation matrix using RANSAC
         if prevPoints.shape[0] > 4:
             H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
 
-            # Handle downscale
+            # Scale translation components back to original resolution
            if self.downscale > 1.0:
                H[0, 2] *= self.downscale
                H[1, 2] *= self.downscale
         else:
             LOGGER.warning("not enough matching points")
 
-        # Store to next iteration
+        # Store current frame data for next iteration
         self.prevFrame = frame.copy()
         self.prevKeyPoints = copy.copy(keypoints)
         self.prevDescriptors = copy.copy(descriptors)
@@ -324,24 +295,24 @@ class GMC:
         frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
         H = np.eye(2, 3)
 
-        # Downscale image
+        # Downscale image for computational efficiency
         if self.downscale > 1.0:
             frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
 
-        # Find the keypoints
+        # Find good features to track
         keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)
 
-        # Handle first frame
+        # Handle first frame initialization
         if not self.initializedFirstFrame or self.prevKeyPoints is None:
             self.prevFrame = frame.copy()
             self.prevKeyPoints = copy.copy(keypoints)
             self.initializedFirstFrame = True
             return H
 
-        # Find correspondences
+        # Calculate optical flow using Lucas-Kanade method
         matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
 
-        # Leave good correspondences only
+        # Extract successfully tracked points
         prevPoints = []
         currPoints = []
 
@@ -353,16 +324,18 @@ class GMC:
         prevPoints = np.array(prevPoints)
         currPoints = np.array(currPoints)
 
-        # Find rigid matrix
+        # Estimate transformation matrix using RANSAC
         if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == currPoints.shape[0]):
             H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
 
+            # Scale translation components back to original resolution
             if self.downscale > 1.0:
                 H[0, 2] *= self.downscale
                 H[1, 2] *= self.downscale
         else:
             LOGGER.warning("not enough matching points")
 
+        # Store current frame data for next iteration
         self.prevFrame = frame.copy()
         self.prevKeyPoints = copy.copy(keypoints)
 
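The 2x3 warp returned by `GMC.apply` is what `STrack.multi_gmc` consumes to compensate track states for camera motion. A hedged sketch of estimating that warp between two synthetic frames with the sparse optical-flow method (frame contents and the simulated 5 px pan are illustrative; noise frames exercise the code path rather than its accuracy):

```python
import numpy as np

from ultralytics.trackers.utils.gmc import GMC

gmc = GMC(method="sparseOptFlow", downscale=2)

rng = np.random.default_rng(0)
frame1 = rng.integers(0, 255, (480, 640, 3), dtype=np.uint8)
frame2 = np.roll(frame1, 5, axis=1)  # simulate a 5 px horizontal camera pan

gmc.apply(frame1)  # first call initializes internal state and returns identity
H = gmc.apply(frame2)  # later calls return a 2x3 affine warp matrix
print(H[:, 2])  # translation terms should roughly recover the pan
```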