ultralytics 8.3.142__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +12 -12
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +39 -39
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +187 -157
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +1 -1
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +1 -1
  93. ultralytics/solutions/instance_segmentation.py +6 -3
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +16 -8
  96. ultralytics/solutions/object_cropper.py +12 -5
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +215 -85
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +42 -28
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
  143. ultralytics-8.3.144.dist-info/RECORD +272 -0
  144. ultralytics-8.3.142.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
@@ -12,24 +12,29 @@ class SpeedEstimator(BaseSolution):
12
12
  A class to estimate the speed of objects in a real-time video stream based on their tracks.
13
13
 
14
14
  This class extends the BaseSolution class and provides functionality for estimating object speeds using
15
- tracking data in video streams.
15
+ tracking data in video streams. Speed is calculated based on pixel displacement over time and converted
16
+ to real-world units using a configurable meters-per-pixel scale factor.
16
17
 
17
18
  Attributes:
18
- spd (Dict[int, float]): Dictionary storing speed data for tracked objects.
19
- trk_hist (Dict[int, float]): Dictionary storing the object tracking data.
20
- max_hist (int): maximum track history before computing speed
21
- meters_per_pixel (float): Real-world meters represented by one pixel (e.g., 0.04 for 4m over 100px).
22
- max_speed (int): Maximum allowed object speed; values above this will be capped at 120 km/h.
19
+ fps (float): Video frame rate for time calculations.
20
+ frame_count (int): Global frame counter for tracking temporal information.
21
+ trk_frame_ids (dict): Maps track IDs to their first frame index.
22
+ spd (dict): Final speed per object in km/h once locked.
23
+ trk_hist (dict): Maps track IDs to deque of position history.
24
+ locked_ids (set): Track IDs whose speed has been finalized.
25
+ max_hist (int): Required frame history before computing speed.
26
+ meter_per_pixel (float): Real-world meters represented by one pixel for scene scale conversion.
27
+ max_speed (int): Maximum allowed object speed; values above this will be capped.
23
28
 
24
29
  Methods:
25
- initialize_region: Initializes the speed estimation region.
26
- process: Processes input frames to estimate object speeds.
27
- store_tracking_history: Stores the tracking history for an object.
28
- extract_tracks: Extracts tracks from the current frame.
29
- display_output: Displays the output with annotations.
30
+ process: Process input frames to estimate object speeds based on tracking data.
31
+ store_tracking_history: Store the tracking history for an object.
32
+ extract_tracks: Extract tracks from the current frame.
33
+ display_output: Display the output with annotations.
30
34
 
31
35
  Examples:
32
- >>> estimator = SpeedEstimator()
36
+ Initialize speed estimator and process a frame
37
+ >>> estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)
33
38
  >>> frame = cv2.imread("frame.jpg")
34
39
  >>> results = estimator.process(frame)
35
40
  >>> cv2.imshow("Speed Estimation", results.plot_im)
@@ -44,15 +49,15 @@ class SpeedEstimator(BaseSolution):
44
49
  """
45
50
  super().__init__(**kwargs)
46
51
 
47
- self.fps = self.CFG["fps"] # assumed video FPS
48
- self.frame_count = 0 # global frame count
52
+ self.fps = self.CFG["fps"] # Video frame rate for time calculations
53
+ self.frame_count = 0 # Global frame counter
49
54
  self.trk_frame_ids = {} # Track ID → first frame index
50
55
  self.spd = {} # Final speed per object (km/h), once locked
51
56
  self.trk_hist = {} # Track ID → deque of (time, position)
52
57
  self.locked_ids = set() # Track IDs whose speed has been finalized
53
58
  self.max_hist = self.CFG["max_hist"] # Required frame history before computing speed
54
59
  self.meter_per_pixel = self.CFG["meter_per_pixel"] # Scene scale, depends on camera details
55
- self.max_speed = self.CFG["max_speed"] # max_speed adjustment
60
+ self.max_speed = self.CFG["max_speed"] # Maximum speed adjustment
56
61
 
57
62
  def process(self, im0):
58
63
  """
@@ -65,6 +70,7 @@ class SpeedEstimator(BaseSolution):
65
70
  (SolutionResults): Contains processed image `plot_im` and `total_tracks` (number of tracked objects).
66
71
 
67
72
  Examples:
73
+ Process a frame for speed estimation
68
74
  >>> estimator = SpeedEstimator()
69
75
  >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
70
76
  >>> results = estimator.process(image)
@@ -89,15 +95,15 @@ class SpeedEstimator(BaseSolution):
89
95
  p0, p1 = trk_hist[0], trk_hist[-1] # First and last points of track
90
96
  dt = (self.frame_count - self.trk_frame_ids[track_id]) / self.fps # Time in seconds
91
97
  if dt > 0:
92
- dx, dy = p1[0] - p0[0], p1[1] - p0[1] # pixel displacement
93
- pixel_distance = sqrt(dx * dx + dy * dy) # get pixel distance
94
- meters = pixel_distance * self.meter_per_pixel # convert to meters
98
+ dx, dy = p1[0] - p0[0], p1[1] - p0[1] # Pixel displacement
99
+ pixel_distance = sqrt(dx * dx + dy * dy) # Calculate pixel distance
100
+ meters = pixel_distance * self.meter_per_pixel # Convert to meters
95
101
  self.spd[track_id] = int(
96
102
  min((meters / dt) * 3.6, self.max_speed)
97
- ) # convert to km/h and store final speed
98
- self.locked_ids.add(track_id) # prevent further updates
99
- self.trk_hist.pop(track_id, None) # free memory
100
- self.trk_frame_ids.pop(track_id, None) # optional: remove frame start too
103
+ ) # Convert to km/h and store final speed
104
+ self.locked_ids.add(track_id) # Prevent further updates
105
+ self.trk_hist.pop(track_id, None) # Free memory
106
+ self.trk_frame_ids.pop(track_id, None) # Remove frame start reference
101
107
 
102
108
  if track_id in self.spd:
103
109
  speed_label = f"{self.spd[track_id]} km/h"
@@ -1,7 +1,7 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import io
4
- from typing import Any
4
+ from typing import Any, List
5
5
 
6
6
  import cv2
7
7
 
@@ -24,7 +24,7 @@ class Inference:
24
24
  model_path (str): Path to the loaded model.
25
25
  model (YOLO): The YOLO model instance.
26
26
  source (str): Selected video source (webcam or video file).
27
- enable_trk (str): Enable tracking option ("Yes" or "No").
27
+ enable_trk (bool): Enable tracking option.
28
28
  conf (float): Confidence threshold for detection.
29
29
  iou (float): IoU threshold for non-maximum suppression.
30
30
  org_frame (Any): Container for the original frame to be displayed.
@@ -33,14 +33,19 @@ class Inference:
33
33
  selected_ind (List[int]): List of selected class indices for detection.
34
34
 
35
35
  Methods:
36
- web_ui: Sets up the Streamlit web interface with custom HTML elements.
37
- sidebar: Configures the Streamlit sidebar for model and inference settings.
38
- source_upload: Handles video file uploads through the Streamlit interface.
39
- configure: Configures the model and loads selected classes for inference.
40
- inference: Performs real-time object detection inference.
36
+ web_ui: Set up the Streamlit web interface with custom HTML elements.
37
+ sidebar: Configure the Streamlit sidebar for model and inference settings.
38
+ source_upload: Handle video file uploads through the Streamlit interface.
39
+ configure: Configure the model and load selected classes for inference.
40
+ inference: Perform real-time object detection inference.
41
41
 
42
42
  Examples:
43
- >>> inf = Inference(model="path/to/model.pt") # Model is an optional argument
43
+ Create an Inference instance with a custom model
44
+ >>> inf = Inference(model="path/to/model.pt")
45
+ >>> inf.inference()
46
+
47
+ Create an Inference instance with default settings
48
+ >>> inf = Inference()
44
49
  >>> inf.inference()
45
50
  """
46
51
 
@@ -62,7 +67,7 @@ class Inference:
62
67
  self.org_frame = None # Container for the original frame display
63
68
  self.ann_frame = None # Container for the annotated frame display
64
69
  self.vid_file_name = None # Video file name or webcam index
65
- self.selected_ind = [] # List of selected class indices for detection
70
+ self.selected_ind: List[int] = [] # List of selected class indices for detection
66
71
  self.model = None # YOLO model instance
67
72
 
68
73
  self.temp_dict = {"model": None, **kwargs}
@@ -73,7 +78,7 @@ class Inference:
73
78
  LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")
74
79
 
75
80
  def web_ui(self):
76
- """Sets up the Streamlit web interface with custom HTML elements."""
81
+ """Set up the Streamlit web interface with custom HTML elements."""
77
82
  menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>""" # Hide main menu style
78
83
 
79
84
  # Main title of streamlit application
@@ -102,7 +107,7 @@ class Inference:
102
107
  "Video",
103
108
  ("webcam", "video"),
104
109
  ) # Add source selection dropdown
105
- self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) # Enable object tracking
110
+ self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes" # Enable object tracking
106
111
  self.conf = float(
107
112
  self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
108
113
  ) # Slider for confidence
@@ -166,7 +171,7 @@ class Inference:
166
171
  break
167
172
 
168
173
  # Process frame with model
169
- if self.enable_trk == "Yes":
174
+ if self.enable_trk:
170
175
  results = self.model.track(
171
176
  frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
172
177
  )
@@ -23,9 +23,9 @@ class TrackZone(BaseSolution):
23
23
  clss (List[int]): Class indices of tracked objects.
24
24
 
25
25
  Methods:
26
- process: Processes each frame of the video, applying region-based tracking.
27
- extract_tracks: Extracts tracking information from the input frame.
28
- display_output: Displays the processed output.
26
+ process: Process each frame of the video, applying region-based tracking.
27
+ extract_tracks: Extract tracking information from the input frame.
28
+ display_output: Display the processed output.
29
29
 
30
30
  Examples:
31
31
  >>> tracker = TrackZone()
@@ -82,7 +82,7 @@ class TrackZone(BaseSolution):
82
82
  )
83
83
 
84
84
  plot_im = annotator.result()
85
- self.display_output(plot_im) # display output with base class function
85
+ self.display_output(plot_im) # Display output with base class function
86
86
 
87
87
  # Return a SolutionResults
88
88
  return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
@@ -2,6 +2,7 @@
2
2
  """Module defines the base classes and structures for object tracking in YOLO."""
3
3
 
4
4
  from collections import OrderedDict
5
+ from typing import Any
5
6
 
6
7
  import numpy as np
7
8
 
@@ -66,15 +67,7 @@ class BaseTrack:
66
67
  _count = 0
67
68
 
68
69
  def __init__(self):
69
- """
70
- Initialize a new track with a unique ID and foundational tracking attributes.
71
-
72
- Examples:
73
- Initialize a new track
74
- >>> track = BaseTrack()
75
- >>> print(track.track_id)
76
- 0
77
- """
70
+ """Initialize a new track with a unique ID and foundational tracking attributes."""
78
71
  self.track_id = 0
79
72
  self.is_activated = False
80
73
  self.state = TrackState.New
@@ -88,37 +81,37 @@ class BaseTrack:
88
81
  self.location = (np.inf, np.inf)
89
82
 
90
83
  @property
91
- def end_frame(self):
92
- """Returns the ID of the most recent frame where the object was tracked."""
84
+ def end_frame(self) -> int:
85
+ """Return the ID of the most recent frame where the object was tracked."""
93
86
  return self.frame_id
94
87
 
95
88
  @staticmethod
96
- def next_id():
89
+ def next_id() -> int:
97
90
  """Increment and return the next unique global track ID for object tracking."""
98
91
  BaseTrack._count += 1
99
92
  return BaseTrack._count
100
93
 
101
- def activate(self, *args):
102
- """Activates the track with provided arguments, initializing necessary attributes for tracking."""
94
+ def activate(self, *args: Any) -> None:
95
+ """Activate the track with provided arguments, initializing necessary attributes for tracking."""
103
96
  raise NotImplementedError
104
97
 
105
- def predict(self):
106
- """Predicts the next state of the track based on the current state and tracking model."""
98
+ def predict(self) -> None:
99
+ """Predict the next state of the track based on the current state and tracking model."""
107
100
  raise NotImplementedError
108
101
 
109
- def update(self, *args, **kwargs):
110
- """Updates the track with new observations and data, modifying its state and attributes accordingly."""
102
+ def update(self, *args: Any, **kwargs: Any) -> None:
103
+ """Update the track with new observations and data, modifying its state and attributes accordingly."""
111
104
  raise NotImplementedError
112
105
 
113
- def mark_lost(self):
114
- """Marks the track as lost by updating its state to TrackState.Lost."""
106
+ def mark_lost(self) -> None:
107
+ """Mark the track as lost by updating its state to TrackState.Lost."""
115
108
  self.state = TrackState.Lost
116
109
 
117
- def mark_removed(self):
118
- """Marks the track as removed by setting its state to TrackState.Removed."""
110
+ def mark_removed(self) -> None:
111
+ """Mark the track as removed by setting its state to TrackState.Removed."""
119
112
  self.state = TrackState.Removed
120
113
 
121
114
  @staticmethod
122
- def reset_id():
115
+ def reset_id() -> None:
123
116
  """Reset the global track ID counter to its initial value."""
124
117
  BaseTrack._count = 0
@@ -1,6 +1,7 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from collections import deque
4
+ from typing import Any, List, Optional
4
5
 
5
6
  import numpy as np
6
7
  import torch
@@ -51,7 +52,9 @@ class BOTrack(STrack):
51
52
 
52
53
  shared_kalman = KalmanFilterXYWH()
53
54
 
54
- def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
55
+ def __init__(
56
+ self, tlwh: np.ndarray, score: float, cls: int, feat: Optional[np.ndarray] = None, feat_history: int = 50
57
+ ):
55
58
  """
56
59
  Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
57
60
 
@@ -59,7 +62,7 @@ class BOTrack(STrack):
59
62
  tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).
60
63
  score (float): Confidence score of the detection.
61
64
  cls (int): Class ID of the detected object.
62
- feat (np.ndarray | None): Feature vector associated with the detection.
65
+ feat (np.ndarray, optional): Feature vector associated with the detection.
63
66
  feat_history (int): Maximum length of the feature history deque.
64
67
 
65
68
  Examples:
@@ -79,7 +82,7 @@ class BOTrack(STrack):
79
82
  self.features = deque([], maxlen=feat_history)
80
83
  self.alpha = 0.9
81
84
 
82
- def update_features(self, feat):
85
+ def update_features(self, feat: np.ndarray) -> None:
83
86
  """Update the feature vector and apply exponential moving average smoothing."""
84
87
  feat /= np.linalg.norm(feat)
85
88
  self.curr_feat = feat
@@ -90,7 +93,7 @@ class BOTrack(STrack):
90
93
  self.features.append(feat)
91
94
  self.smooth_feat /= np.linalg.norm(self.smooth_feat)
92
95
 
93
- def predict(self):
96
+ def predict(self) -> None:
94
97
  """Predict the object's future state using the Kalman filter to update its mean and covariance."""
95
98
  mean_state = self.mean.copy()
96
99
  if self.state != TrackState.Tracked:
@@ -99,20 +102,20 @@ class BOTrack(STrack):
99
102
 
100
103
  self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
101
104
 
102
- def re_activate(self, new_track, frame_id, new_id=False):
105
+ def re_activate(self, new_track: "BOTrack", frame_id: int, new_id: bool = False) -> None:
103
106
  """Reactivate a track with updated features and optionally assign a new ID."""
104
107
  if new_track.curr_feat is not None:
105
108
  self.update_features(new_track.curr_feat)
106
109
  super().re_activate(new_track, frame_id, new_id)
107
110
 
108
- def update(self, new_track, frame_id):
111
+ def update(self, new_track: "BOTrack", frame_id: int) -> None:
109
112
  """Update the track with new detection information and the current frame ID."""
110
113
  if new_track.curr_feat is not None:
111
114
  self.update_features(new_track.curr_feat)
112
115
  super().update(new_track, frame_id)
113
116
 
114
117
  @property
115
- def tlwh(self):
118
+ def tlwh(self) -> np.ndarray:
116
119
  """Return the current bounding box position in `(top left x, top left y, width, height)` format."""
117
120
  if self.mean is None:
118
121
  return self._tlwh.copy()
@@ -121,7 +124,7 @@ class BOTrack(STrack):
121
124
  return ret
122
125
 
123
126
  @staticmethod
124
- def multi_predict(stracks):
127
+ def multi_predict(stracks: List["BOTrack"]) -> None:
125
128
  """Predict the mean and covariance for multiple object tracks using a shared Kalman filter."""
126
129
  if len(stracks) <= 0:
127
130
  return
@@ -136,12 +139,12 @@ class BOTrack(STrack):
136
139
  stracks[i].mean = mean
137
140
  stracks[i].covariance = cov
138
141
 
139
- def convert_coords(self, tlwh):
142
+ def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
140
143
  """Convert tlwh bounding box coordinates to xywh format."""
141
144
  return self.tlwh_to_xywh(tlwh)
142
145
 
143
146
  @staticmethod
144
- def tlwh_to_xywh(tlwh):
147
+ def tlwh_to_xywh(tlwh: np.ndarray) -> np.ndarray:
145
148
  """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
146
149
  ret = np.asarray(tlwh).copy()
147
150
  ret[:2] += ret[2:] / 2
@@ -176,12 +179,12 @@ class BOTSORT(BYTETracker):
176
179
  The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.
177
180
  """
178
181
 
179
- def __init__(self, args, frame_rate=30):
182
+ def __init__(self, args: Any, frame_rate: int = 30):
180
183
  """
181
184
  Initialize BOTSORT object with ReID module and GMC algorithm.
182
185
 
183
186
  Args:
184
- args (object): Parsed command-line arguments containing tracking parameters.
187
+ args (Any): Parsed command-line arguments containing tracking parameters.
185
188
  frame_rate (int): Frame rate of the video being processed.
186
189
 
187
190
  Examples:
@@ -203,11 +206,13 @@ class BOTSORT(BYTETracker):
203
206
  else None
204
207
  )
205
208
 
206
- def get_kalmanfilter(self):
209
+ def get_kalmanfilter(self) -> KalmanFilterXYWH:
207
210
  """Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
208
211
  return KalmanFilterXYWH()
209
212
 
210
- def init_track(self, dets, scores, cls, img=None):
213
+ def init_track(
214
+ self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None
215
+ ) -> List[BOTrack]:
211
216
  """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
212
217
  if len(dets) == 0:
213
218
  return []
@@ -217,7 +222,7 @@ class BOTSORT(BYTETracker):
217
222
  else:
218
223
  return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
219
224
 
220
- def get_dists(self, tracks, detections):
225
+ def get_dists(self, tracks: List[BOTrack], detections: List[BOTrack]) -> np.ndarray:
221
226
  """Calculate distances between tracks and detections using IoU and optionally ReID embeddings."""
222
227
  dists = matching.iou_distance(tracks, detections)
223
228
  dists_mask = dists > (1 - self.proximity_thresh)
@@ -232,11 +237,11 @@ class BOTSORT(BYTETracker):
232
237
  dists = np.minimum(dists, emb_dists)
233
238
  return dists
234
239
 
235
- def multi_predict(self, tracks):
240
+ def multi_predict(self, tracks: List[BOTrack]) -> None:
236
241
  """Predict the mean and covariance of multiple object tracks using a shared Kalman filter."""
237
242
  BOTrack.multi_predict(tracks)
238
243
 
239
- def reset(self):
244
+ def reset(self) -> None:
240
245
  """Reset the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
241
246
  super().reset()
242
247
  self.gmc.reset_params()
@@ -245,14 +250,19 @@ class BOTSORT(BYTETracker):
245
250
  class ReID:
246
251
  """YOLO model as encoder for re-identification."""
247
252
 
248
- def __init__(self, model):
249
- """Initialize encoder for re-identification."""
253
+ def __init__(self, model: str):
254
+ """
255
+ Initialize encoder for re-identification.
256
+
257
+ Args:
258
+ model (str): Path to the YOLO model for re-identification.
259
+ """
250
260
  from ultralytics import YOLO
251
261
 
252
262
  self.model = YOLO(model)
253
263
  self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False) # initialize
254
264
 
255
- def __call__(self, img, dets):
265
+ def __call__(self, img: np.ndarray, dets: np.ndarray) -> List[np.ndarray]:
256
266
  """Extract embeddings for detected objects."""
257
267
  feats = self.model([save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))])
258
268
  if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]: