dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/solutions/speed_estimation.py
@@ -2,61 +2,64 @@

  from collections import deque
  from math import sqrt
+ from typing import Any

  from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
  from ultralytics.utils.plotting import colors


  class SpeedEstimator(BaseSolution):
- """
- A class to estimate the speed of objects in a real-time video stream based on their tracks.
+ """A class to estimate the speed of objects in a real-time video stream based on their tracks.

- This class extends the BaseSolution class and provides functionality for estimating object speeds using
- tracking data in video streams.
+ This class extends the BaseSolution class and provides functionality for estimating object speeds using tracking
+ data in video streams. Speed is calculated based on pixel displacement over time and converted to real-world units
+ using a configurable meters-per-pixel scale factor.

  Attributes:
- spd (Dict[int, float]): Dictionary storing speed data for tracked objects.
- trk_hist (Dict[int, float]): Dictionary storing the object tracking data.
- max_hist (int): maximum track history before computing speed
- meters_per_pixel (float): Real-world meters represented by one pixel (e.g., 0.04 for 4m over 100px).
- max_speed (int): Maximum allowed object speed; values above this will be capped at 120 km/h.
+ fps (float): Video frame rate for time calculations.
+ frame_count (int): Global frame counter for tracking temporal information.
+ trk_frame_ids (dict): Maps track IDs to their first frame index.
+ spd (dict): Final speed per object in km/h once locked.
+ trk_hist (dict): Maps track IDs to deque of position history.
+ locked_ids (set): Track IDs whose speed has been finalized.
+ max_hist (int): Required frame history before computing speed.
+ meter_per_pixel (float): Real-world meters represented by one pixel for scene scale conversion.
+ max_speed (int): Maximum allowed object speed; values above this will be capped.

  Methods:
- initialize_region: Initializes the speed estimation region.
- process: Processes input frames to estimate object speeds.
- store_tracking_history: Stores the tracking history for an object.
- extract_tracks: Extracts tracks from the current frame.
- display_output: Displays the output with annotations.
+ process: Process input frames to estimate object speeds based on tracking data.
+ store_tracking_history: Store the tracking history for an object.
+ extract_tracks: Extract tracks from the current frame.
+ display_output: Display the output with annotations.

  Examples:
- >>> estimator = SpeedEstimator()
+ Initialize speed estimator and process a frame
+ >>> estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)
  >>> frame = cv2.imread("frame.jpg")
  >>> results = estimator.process(frame)
  >>> cv2.imshow("Speed Estimation", results.plot_im)
  """

- def __init__(self, **kwargs):
- """
- Initialize the SpeedEstimator object with speed estimation parameters and data structures.
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize the SpeedEstimator object with speed estimation parameters and data structures.

  Args:
  **kwargs (Any): Additional keyword arguments passed to the parent class.
  """
  super().__init__(**kwargs)

- self.fps = self.CFG["fps"] # assumed video FPS
- self.frame_count = 0 # global frame count
+ self.fps = self.CFG["fps"] # Video frame rate for time calculations
+ self.frame_count = 0 # Global frame counter
  self.trk_frame_ids = {} # Track ID → first frame index
  self.spd = {} # Final speed per object (km/h), once locked
  self.trk_hist = {} # Track ID → deque of (time, position)
  self.locked_ids = set() # Track IDs whose speed has been finalized
  self.max_hist = self.CFG["max_hist"] # Required frame history before computing speed
  self.meter_per_pixel = self.CFG["meter_per_pixel"] # Scene scale, depends on camera details
- self.max_speed = self.CFG["max_speed"] # max_speed adjustment
+ self.max_speed = self.CFG["max_speed"] # Maximum speed adjustment

- def process(self, im0):
- """
- Process an input frame to estimate object speeds based on tracking data.
+ def process(self, im0) -> SolutionResults:
+ """Process an input frame to estimate object speeds based on tracking data.

  Args:
  im0 (np.ndarray): Input image for processing with shape (H, W, C) for RGB images.
@@ -65,6 +68,7 @@ class SpeedEstimator(BaseSolution):
  (SolutionResults): Contains processed image `plot_im` and `total_tracks` (number of tracked objects).

  Examples:
+ Process a frame for speed estimation
  >>> estimator = SpeedEstimator()
  >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
  >>> results = estimator.process(image)
@@ -89,15 +93,15 @@ class SpeedEstimator(BaseSolution):
  p0, p1 = trk_hist[0], trk_hist[-1] # First and last points of track
  dt = (self.frame_count - self.trk_frame_ids[track_id]) / self.fps # Time in seconds
  if dt > 0:
- dx, dy = p1[0] - p0[0], p1[1] - p0[1] # pixel displacement
- pixel_distance = sqrt(dx * dx + dy * dy) # get pixel distance
- meters = pixel_distance * self.meter_per_pixel # convert to meters
+ dx, dy = p1[0] - p0[0], p1[1] - p0[1] # Pixel displacement
+ pixel_distance = sqrt(dx * dx + dy * dy) # Calculate pixel distance
+ meters = pixel_distance * self.meter_per_pixel # Convert to meters
  self.spd[track_id] = int(
  min((meters / dt) * 3.6, self.max_speed)
- ) # convert to km/h and store final speed
- self.locked_ids.add(track_id) # prevent further updates
- self.trk_hist.pop(track_id, None) # free memory
- self.trk_frame_ids.pop(track_id, None) # optional: remove frame start too
+ ) # Convert to km/h and store final speed
+ self.locked_ids.add(track_id) # Prevent further updates
+ self.trk_hist.pop(track_id, None) # Free memory
+ self.trk_frame_ids.pop(track_id, None) # Remove frame start reference

  if track_id in self.spd:
  speed_label = f"{self.spd[track_id]} km/h"
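The new speed logic locks in one value per track once enough history has accumulated: the pixel displacement between the first and last tracked points is scaled by meter_per_pixel, divided by the elapsed time derived from the frame counter and fps, converted to km/h, and capped at max_speed. A minimal standalone sketch of that calculation (the helper name and sample values below are illustrative, not part of the package):

    from math import sqrt

    def estimate_speed_kmh(p0, p1, frame_span, fps=30.0, meter_per_pixel=0.04, max_speed=120):
        """Replicate the SpeedEstimator conversion: pixel displacement -> meters -> km/h, capped at max_speed."""
        dt = frame_span / fps  # elapsed time in seconds between the first and last observation
        if dt <= 0:
            return None
        dx, dy = p1[0] - p0[0], p1[1] - p0[1]  # pixel displacement
        meters = sqrt(dx * dx + dy * dy) * meter_per_pixel  # convert pixel distance to meters
        return int(min((meters / dt) * 3.6, max_speed))  # m/s -> km/h, truncated and capped

    # Example: 100 px of travel over 30 frames at 30 FPS with 0.04 m/px is 4 m in 1 s, i.e. 14 km/h after truncation
    print(estimate_speed_kmh((0, 0), (100, 0), frame_span=30))  # 14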
ultralytics/solutions/streamlit_inference.py
@@ -1,19 +1,22 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

  import io
+ import os
  from typing import Any

  import cv2
+ import torch

  from ultralytics import YOLO
  from ultralytics.utils import LOGGER
  from ultralytics.utils.checks import check_requirements
  from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS

+ torch.classes.__path__ = [] # Torch module __path__._path issue: https://github.com/datalab-to/marker/issues/442
+

  class Inference:
- """
- A class to perform object detection, image classification, image segmentation and pose estimation inference.
+ """A class to perform object detection, image classification, image segmentation and pose estimation inference.

  This class provides functionalities for loading models, configuring settings, uploading video files, and performing
  real-time inference using Streamlit and Ultralytics YOLO models.
@@ -24,29 +27,33 @@ class Inference:
  model_path (str): Path to the loaded model.
  model (YOLO): The YOLO model instance.
  source (str): Selected video source (webcam or video file).
- enable_trk (str): Enable tracking option ("Yes" or "No").
+ enable_trk (bool): Enable tracking option.
  conf (float): Confidence threshold for detection.
  iou (float): IoU threshold for non-maximum suppression.
  org_frame (Any): Container for the original frame to be displayed.
  ann_frame (Any): Container for the annotated frame to be displayed.
  vid_file_name (str | int): Name of the uploaded video file or webcam index.
- selected_ind (List[int]): List of selected class indices for detection.
+ selected_ind (list[int]): List of selected class indices for detection.

  Methods:
- web_ui: Sets up the Streamlit web interface with custom HTML elements.
- sidebar: Configures the Streamlit sidebar for model and inference settings.
- source_upload: Handles video file uploads through the Streamlit interface.
- configure: Configures the model and loads selected classes for inference.
- inference: Performs real-time object detection inference.
+ web_ui: Set up the Streamlit web interface with custom HTML elements.
+ sidebar: Configure the Streamlit sidebar for model and inference settings.
+ source_upload: Handle video file uploads through the Streamlit interface.
+ configure: Configure the model and load selected classes for inference.
+ inference: Perform real-time object detection inference.

  Examples:
- >>> inf = Inference(model="path/to/model.pt") # Model is an optional argument
+ Create an Inference instance with a custom model
+ >>> inf = Inference(model="path/to/model.pt")
+ >>> inf.inference()
+
+ Create an Inference instance with default settings
+ >>> inf = Inference()
  >>> inf.inference()
  """

- def __init__(self, **kwargs: Any):
- """
- Initialize the Inference class, checking Streamlit requirements and setting up the model path.
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize the Inference class, checking Streamlit requirements and setting up the model path.

  Args:
  **kwargs (Any): Additional keyword arguments for model configuration.
@@ -56,13 +63,14 @@ class Inference:

  self.st = st # Reference to the Streamlit module
  self.source = None # Video source selection (webcam or video file)
+ self.img_file_names = [] # List of image file names
  self.enable_trk = False # Flag to toggle object tracking
  self.conf = 0.25 # Confidence threshold for detection
  self.iou = 0.45 # Intersection-over-Union (IoU) threshold for non-maximum suppression
  self.org_frame = None # Container for the original frame display
  self.ann_frame = None # Container for the annotated frame display
  self.vid_file_name = None # Video file name or webcam index
- self.selected_ind = [] # List of selected class indices for detection
+ self.selected_ind: list[int] = [] # List of selected class indices for detection
  self.model = None # YOLO model instance

  self.temp_dict = {"model": None, **kwargs}
@@ -72,18 +80,18 @@ class Inference:

  LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")

- def web_ui(self):
- """Sets up the Streamlit web interface with custom HTML elements."""
+ def web_ui(self) -> None:
+ """Set up the Streamlit web interface with custom HTML elements."""
  menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>""" # Hide main menu style

  # Main title of streamlit application
- main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
+ main_title_cfg = """<div><h1 style="color:#111F68; text-align:center; font-size:40px; margin-top:-50px;
  font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""

  # Subtitle of streamlit application
- sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
- margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power
- of Ultralytics YOLO! 🚀</h4></div>"""
+ sub_title_cfg = """<div><h5 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
+ margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam, videos, and images
+ with the power of Ultralytics YOLO! 🚀</h5></div>"""

  # Set html page configuration and append custom HTML
  self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
@@ -91,7 +99,7 @@ class Inference:
  self.st.markdown(main_title_cfg, unsafe_allow_html=True)
  self.st.markdown(sub_title_cfg, unsafe_allow_html=True)

- def sidebar(self):
+ def sidebar(self) -> None:
  """Configure the Streamlit sidebar for model and inference settings."""
  with self.st.sidebar: # Add Ultralytics LOGO
  logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
@@ -99,24 +107,28 @@ class Inference:

  self.st.sidebar.title("User Configuration") # Add elements to vertical setting menu
  self.source = self.st.sidebar.selectbox(
- "Video",
- ("webcam", "video"),
+ "Source",
+ ("webcam", "video", "image"),
  ) # Add source selection dropdown
- self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) # Enable object tracking
+ if self.source in ["webcam", "video"]:
+ self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes" # Enable object tracking
  self.conf = float(
  self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
  ) # Slider for confidence
  self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01)) # Slider for NMS threshold

- col1, col2 = self.st.columns(2) # Create two columns for displaying frames
- self.org_frame = col1.empty() # Container for original frame
- self.ann_frame = col2.empty() # Container for annotated frame
+ if self.source != "image": # Only create columns for video/webcam
+ col1, col2 = self.st.columns(2) # Create two columns for displaying frames
+ self.org_frame = col1.empty() # Container for original frame
+ self.ann_frame = col2.empty() # Container for annotated frame

- def source_upload(self):
+ def source_upload(self) -> None:
  """Handle video file uploads through the Streamlit interface."""
+ from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS # scope import
+
  self.vid_file_name = ""
  if self.source == "video":
- vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
+ vid_file = self.st.sidebar.file_uploader("Upload Video File", type=VID_FORMATS)
  if vid_file is not None:
  g = io.BytesIO(vid_file.read()) # BytesIO Object
  with open("ultralytics.mp4", "wb") as out: # Open temporary file as bytes
@@ -124,17 +136,41 @@ class Inference:
  self.vid_file_name = "ultralytics.mp4"
  elif self.source == "webcam":
  self.vid_file_name = 0 # Use webcam index 0
-
- def configure(self):
+ elif self.source == "image":
+ import tempfile # scope import
+
+ if imgfiles := self.st.sidebar.file_uploader(
+ "Upload Image Files", type=IMG_FORMATS, accept_multiple_files=True
+ ):
+ for imgfile in imgfiles: # Save each uploaded image to a temporary file
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f".{imgfile.name.split('.')[-1]}") as tf:
+ tf.write(imgfile.read())
+ self.img_file_names.append({"path": tf.name, "name": imgfile.name})
+
+ def configure(self) -> None:
  """Configure the model and load selected classes for inference."""
  # Add dropdown menu for model selection
- available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
- if self.model_path: # If user provided the custom model, insert model without suffix as *.pt is added later
- available_models.insert(0, self.model_path.split(".pt")[0])
+ M_ORD, T_ORD = ["yolo11n", "yolo11s", "yolo11m", "yolo11l", "yolo11x"], ["", "-seg", "-pose", "-obb", "-cls"]
+ available_models = sorted(
+ [
+ x.replace("yolo", "YOLO")
+ for x in GITHUB_ASSETS_STEMS
+ if any(x.startswith(b) for b in M_ORD) and "grayscale" not in x
+ ],
+ key=lambda x: (M_ORD.index(x[:7].lower()), T_ORD.index(x[7:].lower() or "")),
+ )
+ if self.model_path: # Insert user provided custom model in available_models
+ available_models.insert(0, self.model_path)
  selected_model = self.st.sidebar.selectbox("Model", available_models)

  with self.st.spinner("Model is downloading..."):
- self.model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model
+ if selected_model.endswith((".pt", ".onnx", ".torchscript", ".mlpackage", ".engine")) or any(
+ fmt in selected_model for fmt in ("openvino_model", "rknn_model")
+ ):
+ model_path = selected_model
+ else:
+ model_path = f"{selected_model.lower()}.pt" # Default to .pt if no model provided during function call.
+ self.model = YOLO(model_path) # Load the YOLO model
  class_names = list(self.model.names.values()) # Convert dictionary to list of class names
  self.st.success("Model loaded successfully!")

@@ -145,7 +181,28 @@ class Inference:
  if not isinstance(self.selected_ind, list): # Ensure selected_options is a list
  self.selected_ind = list(self.selected_ind)

- def inference(self):
+ def image_inference(self) -> None:
+ """Perform inference on uploaded images."""
+ for img_info in self.img_file_names:
+ img_path = img_info["path"]
+ image = cv2.imread(img_path) # Load and display the original image
+ if image is not None:
+ self.st.markdown(f"#### Processed: {img_info['name']}")
+ col1, col2 = self.st.columns(2)
+ with col1:
+ self.st.image(image, channels="BGR", caption="Original Image")
+ results = self.model(image, conf=self.conf, iou=self.iou, classes=self.selected_ind)
+ annotated_image = results[0].plot()
+ with col2:
+ self.st.image(annotated_image, channels="BGR", caption="Predicted Image")
+ try: # Clean up temporary file
+ os.unlink(img_path)
+ except FileNotFoundError:
+ pass # File doesn't exist, ignore
+ else:
+ self.st.error("Could not load the uploaded image.")
+
+ def inference(self) -> None:
  """Perform real-time object detection inference on video or webcam feed."""
  self.web_ui() # Initialize the web interface
  self.sidebar() # Create the sidebar
@@ -153,7 +210,14 @@ class Inference:
  self.configure() # Configure the app

  if self.st.sidebar.button("Start"):
- stop_button = self.st.button("Stop") # Button to stop the inference
+ if self.source == "image":
+ if self.img_file_names:
+ self.image_inference()
+ else:
+ self.st.info("Please upload an image file to perform inference.")
+ return
+
+ stop_button = self.st.sidebar.button("Stop") # Button to stop the inference
  cap = cv2.VideoCapture(self.vid_file_name) # Capture the video
  if not cap.isOpened():
  self.st.error("Could not open webcam or video source.")
@@ -166,7 +230,7 @@ class Inference:
  break

  # Process frame with model
- if self.enable_trk == "Yes":
+ if self.enable_trk:
  results = self.model.track(
  frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
  )
@@ -179,8 +243,8 @@ class Inference:
  cap.release() # Release the capture
  self.st.stop() # Stop streamlit app

- self.org_frame.image(frame, channels="BGR") # Display original frame
- self.ann_frame.image(annotated_frame, channels="BGR") # Display processed frame
+ self.org_frame.image(frame, channels="BGR", caption="Original Frame") # Display original frame
+ self.ann_frame.image(annotated_frame, channels="BGR", caption="Predicted Frame") # Display processed

  cap.release() # Release the capture
  cv2.destroyAllWindows() # Destroy all OpenCV windows
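The docstring examples above already show the intended entry point; a minimal wrapper script using it might look like the following, where the script name and model choice are illustrative and the app is launched with streamlit run since the class builds a Streamlit UI:

    # app.py -- hypothetical wrapper, run with: streamlit run app.py
    from ultralytics.solutions.streamlit_inference import Inference

    # Passing a model is optional; otherwise the sidebar dropdown lists the YOLO11 assets sorted via M_ORD/T_ORD.
    Inference(model="yolo11n.pt").inference()  # builds the UI and sidebar, handles uploads, then runs inference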
ultralytics/solutions/templates/similarity-search.html
@@ -35,7 +35,6 @@
  align-items: center;
  gap: 1rem;
  margin-bottom: 3rem;
- animation: fadeIn 1s ease-in-out;
  }

  input[type="text"] {
@@ -78,7 +77,6 @@
  gap: 1.5rem;
  max-width: 1600px;
  margin: auto;
- animation: fadeInUp 1s ease-in-out;
  }

  .card {
@@ -102,30 +100,22 @@
  object-fit: cover;
  display: block;
  }
-
- @keyframes fadeIn {
- 0% {
- opacity: 0;
- transform: scale(0.95);
- }
- 100% {
- opacity: 1;
- transform: scale(1);
- }
- }
-
- @keyframes fadeInUp {
- 0% {
- opacity: 0;
- transform: translateY(20px);
- }
- 100% {
- opacity: 1;
- transform: translateY(0);
- }
- }
  </style>
  </head>
+ <script>
+ function filterResults(k) {
+ const cards = document.querySelectorAll(".grid .card");
+ cards.forEach((card, idx) => {
+ card.style.display = idx < k ? "block" : "none";
+ });
+ const buttons = document.querySelectorAll(".topk-btn");
+ buttons.forEach((btn) => btn.classList.remove("active"));
+ event.target.classList.add("active");
+ }
+ document.addEventListener("DOMContentLoaded", () => {
+ filterResults(10);
+ });
+ </script>
  <body>
  <div style="text-align: center; margin-bottom: 1rem">
  <img
@@ -146,6 +136,23 @@
  required
  />
  <button type="submit">Search</button>
+ {% if results %}
+ <div class="top-k-buttons">
+ <button type="button" class="topk-btn" onclick="filterResults(5)">
+ Top 5
+ </button>
+ <button
+ type="button"
+ class="topk-btn active"
+ onclick="filterResults(10)"
+ >
+ Top 10
+ </button>
+ <button type="button" class="topk-btn" onclick="filterResults(30)">
+ Top 30
+ </button>
+ </div>
+ {% endif %}
  </form>

  <!-- Search results grid -->
ultralytics/solutions/trackzone.py
@@ -1,5 +1,7 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from typing import Any
+
  import cv2
  import numpy as np

@@ -8,8 +10,7 @@ from ultralytics.utils.plotting import colors


  class TrackZone(BaseSolution):
- """
- A class to manage region-based object tracking in a video stream.
+ """A class to manage region-based object tracking in a video stream.

  This class extends the BaseSolution class and provides functionality for tracking objects within a specific region
  defined by a polygonal area. Objects outside the region are excluded from tracking.
@@ -17,15 +18,15 @@ class TrackZone(BaseSolution):
  Attributes:
  region (np.ndarray): The polygonal region for tracking, represented as a convex hull of points.
  line_width (int): Width of the lines used for drawing bounding boxes and region boundaries.
- names (List[str]): List of class names that the model can detect.
- boxes (List[np.ndarray]): Bounding boxes of tracked objects.
- track_ids (List[int]): Unique identifiers for each tracked object.
- clss (List[int]): Class indices of tracked objects.
+ names (list[str]): List of class names that the model can detect.
+ boxes (list[np.ndarray]): Bounding boxes of tracked objects.
+ track_ids (list[int]): Unique identifiers for each tracked object.
+ clss (list[int]): Class indices of tracked objects.

  Methods:
- process: Processes each frame of the video, applying region-based tracking.
- extract_tracks: Extracts tracking information from the input frame.
- display_output: Displays the processed output.
+ process: Process each frame of the video, applying region-based tracking.
+ extract_tracks: Extract tracking information from the input frame.
+ display_output: Display the processed output.

  Examples:
  >>> tracker = TrackZone()
@@ -34,9 +35,8 @@ class TrackZone(BaseSolution):
  >>> cv2.imshow("Tracked Frame", results.plot_im)
  """

- def __init__(self, **kwargs):
- """
- Initialize the TrackZone class for tracking objects within a defined region in video streams.
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize the TrackZone class for tracking objects within a defined region in video streams.

  Args:
  **kwargs (Any): Additional keyword arguments passed to the parent class.
@@ -44,20 +44,20 @@ class TrackZone(BaseSolution):
  super().__init__(**kwargs)
  default_region = [(75, 75), (565, 75), (565, 285), (75, 285)]
  self.region = cv2.convexHull(np.array(self.region or default_region, dtype=np.int32))
+ self.mask = None

- def process(self, im0):
- """
- Process the input frame to track objects within a defined region.
+ def process(self, im0: np.ndarray) -> SolutionResults:
+ """Process the input frame to track objects within a defined region.

- This method initializes the annotator, creates a mask for the specified region, extracts tracks
- only from the masked area, and updates tracking information. Objects outside the region are ignored.
+ This method initializes the annotator, creates a mask for the specified region, extracts tracks only from the
+ masked area, and updates tracking information. Objects outside the region are ignored.

  Args:
  im0 (np.ndarray): The input image or frame to be processed.

  Returns:
- (SolutionResults): Contains processed image `plot_im` and `total_tracks` (int) representing the
- total number of tracked objects within the defined region.
+ (SolutionResults): Contains processed image `plot_im` and `total_tracks` (int) representing the total number
+ of tracked objects within the defined region.

  Examples:
  >>> tracker = TrackZone()
@@ -66,10 +66,10 @@ class TrackZone(BaseSolution):
  """
  annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator

- # Create a mask for the region and extract tracks from the masked image
- mask = np.zeros_like(im0[:, :, 0])
- mask = cv2.fillPoly(mask, [self.region], 255)
- masked_frame = cv2.bitwise_and(im0, im0, mask=mask)
+ if self.mask is None: # Create a mask for the region
+ self.mask = np.zeros_like(im0[:, :, 0])
+ cv2.fillPoly(self.mask, [self.region], 255)
+ masked_frame = cv2.bitwise_and(im0, im0, mask=self.mask)
  self.extract_tracks(masked_frame)

  # Draw the region boundary
@@ -82,7 +82,7 @@ class TrackZone(BaseSolution):
  )

  plot_im = annotator.result()
- self.display_output(plot_im) # display output with base class function
+ self.display_output(plot_im) # Display output with base class function

  # Return a SolutionResults
  return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
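The functional change in process() is that the region mask is now built once and cached on the instance instead of being recomputed for every frame. A self-contained sketch of the same cached-mask pattern (variable and function names are illustrative; the region coordinates are the class defaults shown above):

    import cv2
    import numpy as np

    region = cv2.convexHull(np.array([(75, 75), (565, 75), (565, 285), (75, 285)], dtype=np.int32))
    mask = None  # built lazily on the first frame, then reused for the rest of the stream

    def mask_frame(frame: np.ndarray) -> np.ndarray:
        """Zero out pixels outside the tracking region, filling the polygon mask only once."""
        global mask
        if mask is None:  # one-time cost instead of re-filling the polygon on every frame
            mask = np.zeros_like(frame[:, :, 0])
            cv2.fillPoly(mask, [region], 255)
        return cv2.bitwise_and(frame, frame, mask=mask)

    masked = mask_frame(np.zeros((360, 640, 3), dtype=np.uint8))  # e.g. one 640x360 BGR frame

Caching like this assumes the frame size stays constant across the stream, which holds for a single video source.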
ultralytics/solutions/vision_eye.py
@@ -1,18 +1,19 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from typing import Any
+
  from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
  from ultralytics.utils.plotting import colors


  class VisionEye(BaseSolution):
- """
- A class to manage object detection and vision mapping in images or video streams.
+ """A class to manage object detection and vision mapping in images or video streams.

- This class extends the BaseSolution class and provides functionality for detecting objects,
- mapping vision points, and annotating results with bounding boxes and labels.
+ This class extends the BaseSolution class and provides functionality for detecting objects, mapping vision points,
+ and annotating results with bounding boxes and labels.

  Attributes:
- vision_point (Tuple[int, int]): Coordinates (x, y) where vision will view objects and draw tracks.
+ vision_point (tuple[int, int]): Coordinates (x, y) where vision will view objects and draw tracks.

  Methods:
  process: Process the input image to detect objects, annotate them, and apply vision mapping.
@@ -24,9 +25,8 @@ class VisionEye(BaseSolution):
  >>> print(f"Total detected instances: {results.total_tracks}")
  """

- def __init__(self, **kwargs):
- """
- Initialize the VisionEye class for detecting objects and applying vision mapping.
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize the VisionEye class for detecting objects and applying vision mapping.

  Args:
  **kwargs (Any): Keyword arguments passed to the parent class and for configuring vision_point.
@@ -35,12 +35,11 @@ class VisionEye(BaseSolution):
  # Set the vision point where the system will view objects and draw tracks
  self.vision_point = self.CFG["vision_point"]

- def process(self, im0):
- """
- Perform object detection, vision mapping, and annotation on the input image.
+ def process(self, im0) -> SolutionResults:
+ """Perform object detection, vision mapping, and annotation on the input image.

  Args:
- im0 (numpy.ndarray): The input image for detection and annotation.
+ im0 (np.ndarray): The input image for detection and annotation.

  Returns:
  (SolutionResults): Object containing the annotated image and tracking statistics.
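For context, a typical VisionEye call following the docstring pattern above might look like this (the model name and vision_point values are illustrative; vision_point is read from the solution config as shown in __init__):

    import cv2
    from ultralytics.solutions.vision_eye import VisionEye

    visioneye = VisionEye(model="yolo11n.pt", vision_point=(50, 50))  # kwargs forwarded to the BaseSolution config
    frame = cv2.imread("frame.jpg")
    results = visioneye.process(frame)  # SolutionResults carrying plot_im and total_tracks
    print(f"Total detected instances: {results.total_tracks}")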
ultralytics/trackers/__init__.py
@@ -4,4 +4,4 @@ from .bot_sort import BOTSORT
  from .byte_tracker import BYTETracker
  from .track import register_tracker

- __all__ = "register_tracker", "BOTSORT", "BYTETracker" # allow simpler import
+ __all__ = "BOTSORT", "BYTETracker", "register_tracker" # allow simpler import