ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,178 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from collections import defaultdict
4
+
5
+ import cv2
6
+
7
+ from ultralytics import YOLO
8
+ from ultralytics.utils import ASSETS_URL, DEFAULT_CFG_DICT, DEFAULT_SOL_DICT, LOGGER
9
+ from ultralytics.utils.checks import check_imshow, check_requirements
10
+
11
+
12
class BaseSolution:
    """
    A base class for managing Ultralytics Solutions.

    This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,
    and region initialization.

    Attributes:
        LineString (shapely.geometry.LineString): Class for creating line string geometries.
        Polygon (shapely.geometry.Polygon): Class for creating polygon geometries.
        Point (shapely.geometry.Point): Class for creating point geometries.
        prep (Callable): `shapely.prepared.prep`, for building prepared geometries.
        CFG (Dict): Per-instance configuration: solution defaults, then model defaults, then user kwargs (last wins).
        region (List[Tuple[int, int]]): List of coordinate tuples defining a region of interest.
        line_width (int): Width of lines used in visualizations.
        model (ultralytics.YOLO): Loaded YOLO model instance.
        names (Dict[int, str]): Dictionary mapping class indices to class names.
        env_check (bool): Flag indicating whether the environment supports image display.
        track_history (collections.defaultdict): Dictionary to store tracking history for each object.
        track_line (List[Tuple[float, float]] | None): History list of the most recently updated track.
        r_s (shapely geometry | None): Polygon (3+ points) or LineString (2 points) built from `region`.

    Methods:
        extract_tracks: Apply object tracking and extract tracks from an input image.
        store_tracking_history: Store object tracking history for a given track ID and bounding box.
        initialize_region: Initialize the counting region and line segment based on configuration.
        display_output: Display the results of processing, including showing frames or saving results.

    Examples:
        >>> solution = BaseSolution(model="yolov8n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
        >>> solution.initialize_region()
        >>> image = cv2.imread("image.jpg")
        >>> solution.extract_tracks(image)
        >>> solution.display_output(image)
    """

    def __init__(self, IS_CLI=False, **kwargs):
        """
        Initializes the `BaseSolution` class with configuration settings and the YOLO model for Ultralytics solutions.

        Args:
            IS_CLI (bool): Enables CLI mode. When True and no `source` is configured, a demo video is downloaded
                from the Ultralytics assets and used as the source.
            **kwargs (Any): Configuration overrides. They apply to this instance only and are layered on top of the
                default solution and model configuration.
        """
        check_requirements("shapely>=2.0.0")
        from shapely.geometry import LineString, Point, Polygon
        from shapely.prepared import prep

        self.LineString = LineString
        self.Polygon = Polygon
        self.Point = Point
        self.prep = prep
        self.annotator = None  # Initialize annotator
        self.tracks = None
        self.track_data = None
        self.boxes = []
        self.clss = []
        self.track_ids = []
        self.track_line = None
        self.r_s = None

        # Build a per-instance config. The module-level default dicts are deliberately NOT mutated. Mutating them
        # (as `.update(kwargs)` did) leaked one solution's overrides into every solution created afterwards.
        # Precedence matches the previous behavior: model defaults override solution defaults, and kwargs override both.
        sol_cfg = {**DEFAULT_SOL_DICT, **kwargs}
        self.CFG = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT, **kwargs}
        LOGGER.info(f"Ultralytics Solutions: ✅ {sol_cfg}")

        self.region = self.CFG["region"]  # Store region data for other classes usage
        self.line_width = (
            self.CFG["line_width"] if self.CFG["line_width"] is not None else 2
        )  # Store line_width for usage

        # Load Model and store classes names
        if self.CFG["model"] is None:
            self.CFG["model"] = "yolo11n.pt"
        self.model = YOLO(self.CFG["model"])
        self.names = self.model.names

        self.track_add_args = {  # Tracker additional arguments for advance configuration
            k: self.CFG[k] for k in ["verbose", "iou", "conf", "device", "max_det", "half", "tracker"]
        }

        if IS_CLI and self.CFG["source"] is None:
            # str() so that Path-like model arguments don't raise TypeError on the substring check.
            d_s = "solutions_ci_demo.mp4" if "-pose" not in str(self.CFG["model"]) else "solution_ci_pose_demo.mp4"
            LOGGER.warning(f"⚠️ WARNING: source not provided. using default source {ASSETS_URL}/{d_s}")
            from ultralytics.utils.downloads import safe_download

            safe_download(f"{ASSETS_URL}/{d_s}")  # download source from ultralytics assets
            self.CFG["source"] = d_s  # set default source

        # Initialize environment and region setup
        self.env_check = check_imshow(warn=True)
        self.track_history = defaultdict(list)

    def extract_tracks(self, im0):
        """
        Applies object tracking and extracts tracks from an input image or frame.

        Populates `self.tracks`, `self.track_data`, `self.boxes` (xyxy, on CPU), `self.clss` and `self.track_ids`.
        When no tracked IDs are present, the three per-object lists are reset to empty lists.

        Args:
            im0 (ndarray): The input image or frame.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.extract_tracks(frame)
        """
        self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)

        # Extract tracks for OBB or object detection (an empty/absent OBB result falls back to regular boxes)
        self.track_data = self.tracks[0].obb or self.tracks[0].boxes

        if self.track_data and self.track_data.id is not None:
            self.boxes = self.track_data.xyxy.cpu()
            self.clss = self.track_data.cls.cpu().tolist()
            self.track_ids = self.track_data.id.int().cpu().tolist()
        else:
            LOGGER.warning("WARNING ⚠️ no tracks found!")
            self.boxes, self.clss, self.track_ids = [], [], []

    def store_tracking_history(self, track_id, box):
        """
        Stores the tracking history of an object.

        Appends the center point of the bounding box to the object's history and keeps at most the 30 most recent
        points. `self.track_line` is left pointing at that object's history list.

        Args:
            track_id (int): The unique identifier for the tracked object.
            box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].

        Examples:
            >>> solution = BaseSolution()
            >>> solution.store_tracking_history(1, [100, 200, 300, 400])
        """
        self.track_line = self.track_history[track_id]
        self.track_line.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))
        if len(self.track_line) > 30:
            self.track_line.pop(0)

    def initialize_region(self):
        """
        Initialize the counting region and line segment based on configuration settings.

        A missing region is replaced by a default thin rectangle. With 3 or more points the region becomes a Polygon;
        with exactly 2 points it becomes a LineString.

        Raises:
            ValueError: If the region contains fewer than 2 points, which defines neither a line nor a polygon.
        """
        if self.region is None:
            self.region = [(20, 400), (1080, 400), (1080, 360), (20, 360)]
        if len(self.region) < 2:
            raise ValueError(f"region must contain at least 2 points, got {len(self.region)}: {self.region}")
        self.r_s = (
            self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region)
        )  # region or line

    def display_output(self, im0):
        """
        Display the results of the processing, which could involve showing frames, printing counts, or saving results.

        This method is responsible for visualizing the output of the object detection and tracking process. It displays
        the processed frame with annotations in an OpenCV window.

        Args:
            im0 (numpy.ndarray): The input image or frame that has been processed and annotated.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.display_output(frame)

        Notes:
            - This method will only display output if the 'show' configuration is set to True and the environment
              supports image display.
            - Pressing 'q' only returns early from this call. It does not close the window or stop the caller's loop.
        """
        if self.CFG.get("show") and self.env_check:
            cv2.imshow("Ultralytics Solutions", im0)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                return
@@ -1,198 +1,110 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
- from collections import defaultdict
4
3
  from time import time
5
4
 
6
- import cv2
7
5
  import numpy as np
8
6
 
9
- from ultralytics.utils.checks import check_imshow
7
+ from ultralytics.solutions.solutions import BaseSolution
10
8
  from ultralytics.utils.plotting import Annotator, colors
11
9
 
12
10
 
13
class SpeedEstimator(BaseSolution):
    """
    A class to estimate the speed of objects in a real-time video stream based on their tracks.

    This class extends the BaseSolution class and provides functionality for estimating object speeds using
    tracking data in video streams. A track's speed is estimated once, when the segment between its previous and
    current centroid intersects the region. The value is the vertical (y-axis) pixel displacement divided by the
    elapsed wall-clock seconds. The "km/h" label is not calibrated to real-world units.

    Attributes:
        spd (Dict[int, float]): Dictionary storing speed data for tracked objects.
        trkd_ids (List[int]): List of tracked object IDs that have already been speed-estimated.
        trk_pt (Dict[int, float]): Previous timestamps for tracked objects (0 means the track has not been timed yet).
        trk_pp (Dict[int, Tuple[float, float]]): Dictionary storing previous positions for tracked objects.
        annotator (Annotator): Annotator object for drawing on images.
        region (List[Tuple[int, int]]): List of points defining the speed estimation region.
        track_line (List[Tuple[float, float]]): List of points representing the object's track.
        r_s (LineString): LineString object representing the speed estimation region.

    Methods:
        initialize_region: Initializes the speed estimation region.
        estimate_speed: Estimates the speed of objects based on tracking data.
        store_tracking_history: Stores the tracking history for an object.
        extract_tracks: Extracts tracks from the current frame.
        display_output: Displays the output with annotations.

    Examples:
        >>> estimator = SpeedEstimator()
        >>> frame = cv2.imread("frame.jpg")
        >>> processed_frame = estimator.estimate_speed(frame)
        >>> cv2.imshow("Speed Estimation", processed_frame)
    """

    def __init__(self, **kwargs):
        """Initializes the SpeedEstimator object with speed estimation parameters and data structures."""
        super().__init__(**kwargs)

        self.initialize_region()  # Initialize speed region

        self.spd = {}  # speed estimate per track ID
        self.trkd_ids = []  # track IDs whose speed has already been estimated
        self.trk_pt = {}  # previous timestamp per track ID (0 = first sighting, not yet timed)
        self.trk_pp = {}  # previous centroid per track ID

    def estimate_speed(self, im0):
        """
        Estimates the speed of objects based on tracking data.

        Args:
            im0 (np.ndarray): Input image for processing. Shape is typically (H, W, C) for RGB images.

        Returns:
            (np.ndarray): Processed image with speed estimations and annotations.

        Examples:
            >>> estimator = SpeedEstimator()
            >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
            >>> processed_image = estimator.estimate_speed(image)
        """
        self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
        self.extract_tracks(im0)  # Extract tracks

        self.annotator.draw_region(
            reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2
        )  # Draw region

        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            self.store_tracking_history(track_id, box)  # Store track history
            track_color = colors(int(track_id), True)

            # First sighting: no previous timestamp yet (0) and the previous point is the current centroid
            self.trk_pt.setdefault(track_id, 0)
            self.trk_pp.setdefault(track_id, self.track_line[-1])

            speed_label = f"{int(self.spd[track_id])} km/h" if track_id in self.spd else self.names[int(cls)]
            self.annotator.box_label(box, label=speed_label, color=track_color)  # Draw bounding box

            # Draw tracks of objects
            self.annotator.draw_centroid_and_tracks(self.track_line, color=track_color, track_thickness=self.line_width)

            # The movement since the previous frame is "known" when it intersects the speed region
            crossed_region = self.LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.r_s)

            # Skip tracks that have never been timed (trk_pt == 0). Otherwise time() - 0 is roughly the Unix epoch
            # offset, which would freeze the track's speed at ~0 because each ID is only estimated once.
            if crossed_region and track_id not in self.trkd_ids and self.trk_pt[track_id] != 0:
                self.trkd_ids.append(track_id)
                time_difference = time() - self.trk_pt[track_id]
                if time_difference > 0:
                    # Vertical pixel displacement per second (not a calibrated real-world speed)
                    self.spd[track_id] = np.abs(self.track_line[-1][1] - self.trk_pp[track_id][1]) / time_difference

            self.trk_pt[track_id] = time()
            self.trk_pp[track_id] = self.track_line[-1]

        self.display_output(im0)  # display output with base class function

        return im0  # return output image for more usage
@@ -0,0 +1,190 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ import io
4
+ from typing import Any
5
+
6
+ import cv2
7
+
8
+ from ultralytics import YOLO
9
+ from ultralytics.utils import LOGGER
10
+ from ultralytics.utils.checks import check_requirements
11
+ from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
12
+
13
+
14
class Inference:
    """
    A class to perform object detection, image classification, image segmentation and pose estimation inference using
    Streamlit and Ultralytics YOLO models. It provides functionalities such as loading models, configuring settings,
    uploading video files, and performing real-time inference.

    Attributes:
        st (module): Streamlit module for UI creation.
        temp_dict (dict): Dictionary holding the "model" entry (default None) merged with any extra keyword arguments.
        model_path (str | None): User-provided model path, or None when only the built-in YOLO11 models are offered.
        model (YOLO): The YOLO model instance.
        source (str): Selected video source, "webcam" or "video" (set by `sidebar`).
        enable_trk (str): Tracking choice, "Yes" or "No" (set by `sidebar`; False before that).
        conf (float): Confidence threshold.
        iou (float): IoU threshold for non-max suppression.
        vid_file_name (str | int): Uploaded video path, "" when nothing was uploaded, or 0 for the webcam.
        selected_ind (list): List of selected class indices.

    Methods:
        web_ui: Sets up the Streamlit web interface with custom HTML elements.
        sidebar: Configures the Streamlit sidebar for model and inference settings.
        source_upload: Handles video file uploads through the Streamlit interface.
        configure: Configures the model and loads selected classes for inference.
        inference: Performs real-time object detection inference.

    Examples:
        >>> inf = solutions.Inference(model="path/to/model.pt")  # Model is not necessary argument.
        >>> inf.inference()
    """

    def __init__(self, **kwargs: Any):
        """
        Initializes the Inference class, checking Streamlit requirements and setting up the model path.

        Args:
            **kwargs (Any): Additional keyword arguments; a "model" entry, if given, is used as the custom model path.
        """
        check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
        import streamlit as st

        self.st = st  # Reference to the Streamlit module
        self.source = None  # Placeholder for video or webcam source details
        self.enable_trk = False  # Flag to toggle object tracking; sidebar() replaces it with "Yes"/"No"
        self.conf = 0.25  # Confidence threshold for detection
        self.iou = 0.45  # Intersection-over-Union (IoU) threshold for non-maximum suppression
        self.org_frame = None  # Container for the original frame to be displayed
        self.ann_frame = None  # Container for the annotated frame to be displayed
        self.vid_file_name = None  # Holds the name of the video file
        self.selected_ind = []  # List of selected classes for detection or tracking
        self.model = None  # Container for the loaded model instance

        self.temp_dict = {"model": None, **kwargs}
        self.model_path = None  # Store model file name with path
        if self.temp_dict["model"] is not None:
            self.model_path = self.temp_dict["model"]

        LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")

    def web_ui(self):
        """Sets up the Streamlit page config, hides the main menu and renders the title and subtitle."""
        menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style

        # Main title of streamlit application
        main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
        font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""

        # Subtitle of streamlit application
        sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power
        of Ultralytics YOLO! 🚀</h4></div>"""

        # Set html page configuration and append custom HTML
        self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
        self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
        self.st.markdown(main_title_cfg, unsafe_allow_html=True)
        self.st.markdown(sub_title_cfg, unsafe_allow_html=True)

    def sidebar(self):
        """Adds the logo, source/tracking/threshold controls to the sidebar and creates the two frame placeholders."""
        with self.st.sidebar:  # Add Ultralytics LOGO
            logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
            self.st.image(logo, width=250)

        self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
        self.source = self.st.sidebar.selectbox(
            "Video",
            ("webcam", "video"),
        )  # Add source selection dropdown
        self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No"))  # Enable object tracking
        self.conf = float(
            self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
        )  # Slider for confidence
        self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01))  # Slider for NMS threshold

        # Left column shows the raw frame, right column the annotated one
        col1, col2 = self.st.columns(2)
        self.org_frame = col1.empty()
        self.ann_frame = col2.empty()

    def source_upload(self):
        """
        Resolves the capture source into `vid_file_name`.

        In "video" mode an uploaded file is written to "ultralytics.mp4" in the working directory and that name is used;
        if nothing was uploaded the name stays "". In "webcam" mode it is set to 0 (the default camera index).
        """
        self.vid_file_name = ""
        if self.source == "video":
            vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
            if vid_file is not None:
                g = io.BytesIO(vid_file.read())  # BytesIO Object
                with open("ultralytics.mp4", "wb") as out:  # Open temporary file as bytes
                    out.write(g.read())  # Read bytes into file
                self.vid_file_name = "ultralytics.mp4"
        elif self.source == "webcam":
            self.vid_file_name = 0

    def _model_weights(self, selected_model):
        """
        Returns the weights path to load for an entry chosen in the model selector.

        Args:
            selected_model (str): Entry returned by the sidebar model selectbox.

        Returns:
            (str): `model_path` unchanged when the custom entry is selected, so its case and file suffix are preserved;
                otherwise the built-in name lower-cased with ".pt" appended, e.g. "YOLO11n" -> "yolo11n.pt".
        """
        if self.model_path and selected_model == self.model_path.split(".pt")[0]:
            return self.model_path
        return f"{selected_model.lower()}.pt"

    def configure(self):
        """
        Adds the model and class selectors to the sidebar and loads the selected model.

        Built-in choices are the "yolo11*" stems of `GITHUB_ASSETS_STEMS`, shown with a "YOLO" prefix. A user-provided
        `model_path` is listed first (displayed without its ".pt" suffix) and is loaded exactly as given.
        """
        # Add dropdown menu for model selection
        available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
        if self.model_path:  # List the custom model first, displayed without its ".pt" suffix
            available_models.insert(0, self.model_path.split(".pt")[0])
        selected_model = self.st.sidebar.selectbox("Model", available_models)

        with self.st.spinner("Model is downloading..."):
            # Custom paths must not be lower-cased or re-suffixed, hence the helper instead of a plain f-string
            self.model = YOLO(self._model_weights(selected_model))  # Load the YOLO model
            class_names = list(self.model.names.values())  # Convert dictionary to list of class names
        self.st.success("Model loaded successfully!")

        # Multiselect box with class names and get indices of selected classes (always a list)
        selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
        self.selected_ind = [class_names.index(option) for option in selected_classes]

    def inference(self):
        """
        Builds the UI and, once "Start" is pressed, streams frames from the selected source through the model.

        Each original frame is displayed next to `results[0].plot()`. With tracking enabled, `model.track(...,
        persist=True)` is used; otherwise a plain prediction call is made with the same conf/iou/classes settings.
        """
        self.web_ui()  # Initialize the web interface
        self.sidebar()  # Create the sidebar
        self.source_upload()  # Upload the video source
        self.configure()  # Configure the app

        if self.st.sidebar.button("Start"):
            stop_button = self.st.button("Stop")  # Button to stop the inference
            cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
            if not cap.isOpened():
                # In "video" mode this typically means no file has been uploaded yet
                self.st.error("Could not open webcam." if self.source == "webcam" else "Could not open video file.")
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    # For an uploaded video this is the normal end of the file; only a webcam failure is unexpected
                    if self.source == "webcam":
                        self.st.warning(
                            "Failed to read frame from webcam. Please verify the webcam is connected properly."
                        )
                    break

                # Store model predictions
                if self.enable_trk == "Yes":
                    results = self.model.track(
                        frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
                    )
                else:
                    results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
                annotated_frame = results[0].plot()  # Add annotations on frame

                if stop_button:
                    cap.release()  # Release the capture
                    self.st.stop()  # Stop streamlit app

                self.org_frame.image(frame, channels="BGR")  # Display original frame
                self.ann_frame.image(annotated_frame, channels="BGR")  # Display processed frame

            cap.release()  # Release the capture
            # No cv2.destroyAllWindows(): frames are rendered by Streamlit, no OpenCV window is ever opened, and the
            # call may raise on headless OpenCV builds.
181
+
182
+
183
if __name__ == "__main__":
    import sys  # Command-line arguments supply the optional model path

    # Treat the first positional CLI argument (if any) as the model to offer in the app
    cli_args = sys.argv[1:]
    model = cli_args[0] if cli_args else None

    # Build the Streamlit app around that model and start the inference loop
    Inference(model=model).inference()