dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
  2. dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
  3. dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +23 -0
  8. tests/conftest.py +59 -0
  9. tests/test_cli.py +131 -0
  10. tests/test_cuda.py +216 -0
  11. tests/test_engine.py +157 -0
  12. tests/test_exports.py +309 -0
  13. tests/test_integrations.py +151 -0
  14. tests/test_python.py +777 -0
  15. tests/test_solutions.py +371 -0
  16. ultralytics/__init__.py +48 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1028 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  29. ultralytics/cfg/datasets/VOC.yaml +102 -0
  30. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  31. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  32. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  33. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  34. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  35. ultralytics/cfg/datasets/coco.yaml +118 -0
  36. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  37. ultralytics/cfg/datasets/coco128.yaml +101 -0
  38. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  39. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  40. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  41. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  42. ultralytics/cfg/datasets/coco8.yaml +101 -0
  43. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  44. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  45. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  46. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  47. ultralytics/cfg/datasets/dota8.yaml +35 -0
  48. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  49. ultralytics/cfg/datasets/kitti.yaml +27 -0
  50. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  51. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  52. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  53. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  54. ultralytics/cfg/datasets/signature.yaml +21 -0
  55. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  56. ultralytics/cfg/datasets/xView.yaml +155 -0
  57. ultralytics/cfg/default.yaml +130 -0
  58. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  59. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  60. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  61. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  62. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  63. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  64. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  65. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  67. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  68. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  69. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  70. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  71. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  72. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  73. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  74. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  75. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  77. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  78. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  79. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  80. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  81. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  82. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  83. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  84. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  85. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  86. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  87. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  88. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  89. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  90. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  91. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  92. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  93. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  94. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  95. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  97. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  99. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  100. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  102. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  103. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  104. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  105. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  106. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  109. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  110. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  111. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  112. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  113. ultralytics/cfg/trackers/botsort.yaml +21 -0
  114. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  115. ultralytics/data/__init__.py +26 -0
  116. ultralytics/data/annotator.py +66 -0
  117. ultralytics/data/augment.py +2801 -0
  118. ultralytics/data/base.py +435 -0
  119. ultralytics/data/build.py +437 -0
  120. ultralytics/data/converter.py +855 -0
  121. ultralytics/data/dataset.py +834 -0
  122. ultralytics/data/loaders.py +704 -0
  123. ultralytics/data/scripts/download_weights.sh +18 -0
  124. ultralytics/data/scripts/get_coco.sh +61 -0
  125. ultralytics/data/scripts/get_coco128.sh +18 -0
  126. ultralytics/data/scripts/get_imagenet.sh +52 -0
  127. ultralytics/data/split.py +138 -0
  128. ultralytics/data/split_dota.py +344 -0
  129. ultralytics/data/utils.py +798 -0
  130. ultralytics/engine/__init__.py +1 -0
  131. ultralytics/engine/exporter.py +1580 -0
  132. ultralytics/engine/model.py +1125 -0
  133. ultralytics/engine/predictor.py +508 -0
  134. ultralytics/engine/results.py +1522 -0
  135. ultralytics/engine/trainer.py +977 -0
  136. ultralytics/engine/tuner.py +449 -0
  137. ultralytics/engine/validator.py +387 -0
  138. ultralytics/hub/__init__.py +166 -0
  139. ultralytics/hub/auth.py +151 -0
  140. ultralytics/hub/google/__init__.py +174 -0
  141. ultralytics/hub/session.py +422 -0
  142. ultralytics/hub/utils.py +162 -0
  143. ultralytics/models/__init__.py +9 -0
  144. ultralytics/models/fastsam/__init__.py +7 -0
  145. ultralytics/models/fastsam/model.py +79 -0
  146. ultralytics/models/fastsam/predict.py +169 -0
  147. ultralytics/models/fastsam/utils.py +23 -0
  148. ultralytics/models/fastsam/val.py +38 -0
  149. ultralytics/models/nas/__init__.py +7 -0
  150. ultralytics/models/nas/model.py +98 -0
  151. ultralytics/models/nas/predict.py +56 -0
  152. ultralytics/models/nas/val.py +38 -0
  153. ultralytics/models/rtdetr/__init__.py +7 -0
  154. ultralytics/models/rtdetr/model.py +63 -0
  155. ultralytics/models/rtdetr/predict.py +88 -0
  156. ultralytics/models/rtdetr/train.py +89 -0
  157. ultralytics/models/rtdetr/val.py +216 -0
  158. ultralytics/models/sam/__init__.py +25 -0
  159. ultralytics/models/sam/amg.py +275 -0
  160. ultralytics/models/sam/build.py +365 -0
  161. ultralytics/models/sam/build_sam3.py +377 -0
  162. ultralytics/models/sam/model.py +169 -0
  163. ultralytics/models/sam/modules/__init__.py +1 -0
  164. ultralytics/models/sam/modules/blocks.py +1067 -0
  165. ultralytics/models/sam/modules/decoders.py +495 -0
  166. ultralytics/models/sam/modules/encoders.py +794 -0
  167. ultralytics/models/sam/modules/memory_attention.py +298 -0
  168. ultralytics/models/sam/modules/sam.py +1160 -0
  169. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  170. ultralytics/models/sam/modules/transformer.py +344 -0
  171. ultralytics/models/sam/modules/utils.py +512 -0
  172. ultralytics/models/sam/predict.py +3940 -0
  173. ultralytics/models/sam/sam3/__init__.py +3 -0
  174. ultralytics/models/sam/sam3/decoder.py +546 -0
  175. ultralytics/models/sam/sam3/encoder.py +529 -0
  176. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  177. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  178. ultralytics/models/sam/sam3/model_misc.py +199 -0
  179. ultralytics/models/sam/sam3/necks.py +129 -0
  180. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  181. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  182. ultralytics/models/sam/sam3/vitdet.py +547 -0
  183. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  184. ultralytics/models/utils/__init__.py +1 -0
  185. ultralytics/models/utils/loss.py +466 -0
  186. ultralytics/models/utils/ops.py +315 -0
  187. ultralytics/models/yolo/__init__.py +7 -0
  188. ultralytics/models/yolo/classify/__init__.py +7 -0
  189. ultralytics/models/yolo/classify/predict.py +90 -0
  190. ultralytics/models/yolo/classify/train.py +202 -0
  191. ultralytics/models/yolo/classify/val.py +216 -0
  192. ultralytics/models/yolo/detect/__init__.py +7 -0
  193. ultralytics/models/yolo/detect/predict.py +122 -0
  194. ultralytics/models/yolo/detect/train.py +227 -0
  195. ultralytics/models/yolo/detect/val.py +507 -0
  196. ultralytics/models/yolo/model.py +430 -0
  197. ultralytics/models/yolo/obb/__init__.py +7 -0
  198. ultralytics/models/yolo/obb/predict.py +56 -0
  199. ultralytics/models/yolo/obb/train.py +79 -0
  200. ultralytics/models/yolo/obb/val.py +302 -0
  201. ultralytics/models/yolo/pose/__init__.py +7 -0
  202. ultralytics/models/yolo/pose/predict.py +65 -0
  203. ultralytics/models/yolo/pose/train.py +110 -0
  204. ultralytics/models/yolo/pose/val.py +248 -0
  205. ultralytics/models/yolo/segment/__init__.py +7 -0
  206. ultralytics/models/yolo/segment/predict.py +109 -0
  207. ultralytics/models/yolo/segment/train.py +69 -0
  208. ultralytics/models/yolo/segment/val.py +307 -0
  209. ultralytics/models/yolo/world/__init__.py +5 -0
  210. ultralytics/models/yolo/world/train.py +173 -0
  211. ultralytics/models/yolo/world/train_world.py +178 -0
  212. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  213. ultralytics/models/yolo/yoloe/predict.py +162 -0
  214. ultralytics/models/yolo/yoloe/train.py +287 -0
  215. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  216. ultralytics/models/yolo/yoloe/val.py +206 -0
  217. ultralytics/nn/__init__.py +27 -0
  218. ultralytics/nn/autobackend.py +964 -0
  219. ultralytics/nn/modules/__init__.py +182 -0
  220. ultralytics/nn/modules/activation.py +54 -0
  221. ultralytics/nn/modules/block.py +1947 -0
  222. ultralytics/nn/modules/conv.py +669 -0
  223. ultralytics/nn/modules/head.py +1183 -0
  224. ultralytics/nn/modules/transformer.py +793 -0
  225. ultralytics/nn/modules/utils.py +159 -0
  226. ultralytics/nn/tasks.py +1768 -0
  227. ultralytics/nn/text_model.py +356 -0
  228. ultralytics/py.typed +1 -0
  229. ultralytics/solutions/__init__.py +41 -0
  230. ultralytics/solutions/ai_gym.py +108 -0
  231. ultralytics/solutions/analytics.py +264 -0
  232. ultralytics/solutions/config.py +107 -0
  233. ultralytics/solutions/distance_calculation.py +123 -0
  234. ultralytics/solutions/heatmap.py +125 -0
  235. ultralytics/solutions/instance_segmentation.py +86 -0
  236. ultralytics/solutions/object_blurrer.py +89 -0
  237. ultralytics/solutions/object_counter.py +190 -0
  238. ultralytics/solutions/object_cropper.py +87 -0
  239. ultralytics/solutions/parking_management.py +280 -0
  240. ultralytics/solutions/queue_management.py +93 -0
  241. ultralytics/solutions/region_counter.py +133 -0
  242. ultralytics/solutions/security_alarm.py +151 -0
  243. ultralytics/solutions/similarity_search.py +219 -0
  244. ultralytics/solutions/solutions.py +828 -0
  245. ultralytics/solutions/speed_estimation.py +114 -0
  246. ultralytics/solutions/streamlit_inference.py +260 -0
  247. ultralytics/solutions/templates/similarity-search.html +156 -0
  248. ultralytics/solutions/trackzone.py +88 -0
  249. ultralytics/solutions/vision_eye.py +67 -0
  250. ultralytics/trackers/__init__.py +7 -0
  251. ultralytics/trackers/basetrack.py +115 -0
  252. ultralytics/trackers/bot_sort.py +257 -0
  253. ultralytics/trackers/byte_tracker.py +469 -0
  254. ultralytics/trackers/track.py +116 -0
  255. ultralytics/trackers/utils/__init__.py +1 -0
  256. ultralytics/trackers/utils/gmc.py +339 -0
  257. ultralytics/trackers/utils/kalman_filter.py +482 -0
  258. ultralytics/trackers/utils/matching.py +154 -0
  259. ultralytics/utils/__init__.py +1450 -0
  260. ultralytics/utils/autobatch.py +118 -0
  261. ultralytics/utils/autodevice.py +205 -0
  262. ultralytics/utils/benchmarks.py +728 -0
  263. ultralytics/utils/callbacks/__init__.py +5 -0
  264. ultralytics/utils/callbacks/base.py +233 -0
  265. ultralytics/utils/callbacks/clearml.py +146 -0
  266. ultralytics/utils/callbacks/comet.py +625 -0
  267. ultralytics/utils/callbacks/dvc.py +197 -0
  268. ultralytics/utils/callbacks/hub.py +110 -0
  269. ultralytics/utils/callbacks/mlflow.py +134 -0
  270. ultralytics/utils/callbacks/neptune.py +126 -0
  271. ultralytics/utils/callbacks/platform.py +453 -0
  272. ultralytics/utils/callbacks/raytune.py +42 -0
  273. ultralytics/utils/callbacks/tensorboard.py +123 -0
  274. ultralytics/utils/callbacks/wb.py +188 -0
  275. ultralytics/utils/checks.py +1020 -0
  276. ultralytics/utils/cpu.py +85 -0
  277. ultralytics/utils/dist.py +123 -0
  278. ultralytics/utils/downloads.py +529 -0
  279. ultralytics/utils/errors.py +35 -0
  280. ultralytics/utils/events.py +113 -0
  281. ultralytics/utils/export/__init__.py +7 -0
  282. ultralytics/utils/export/engine.py +237 -0
  283. ultralytics/utils/export/imx.py +325 -0
  284. ultralytics/utils/export/tensorflow.py +231 -0
  285. ultralytics/utils/files.py +219 -0
  286. ultralytics/utils/git.py +137 -0
  287. ultralytics/utils/instance.py +484 -0
  288. ultralytics/utils/logger.py +506 -0
  289. ultralytics/utils/loss.py +849 -0
  290. ultralytics/utils/metrics.py +1563 -0
  291. ultralytics/utils/nms.py +337 -0
  292. ultralytics/utils/ops.py +664 -0
  293. ultralytics/utils/patches.py +201 -0
  294. ultralytics/utils/plotting.py +1047 -0
  295. ultralytics/utils/tal.py +404 -0
  296. ultralytics/utils/torch_utils.py +984 -0
  297. ultralytics/utils/tqdm.py +443 -0
  298. ultralytics/utils/triton.py +112 -0
  299. ultralytics/utils/tuner.py +168 -0
@@ -0,0 +1,508 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+ """
3
+ Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
4
+
5
+ Usage - sources:
6
+ $ yolo mode=predict model=yolo11n.pt source=0 # webcam
7
+ img.jpg # image
8
+ vid.mp4 # video
9
+ screen # screenshot
10
+ path/ # directory
11
+ list.txt # list of images
12
+ list.streams # list of streams
13
+ 'path/*.jpg' # glob
14
+ 'https://youtu.be/LNwODJXcvt4' # YouTube
15
+ 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream
16
+
17
+ Usage - formats:
18
+ $ yolo mode=predict model=yolo11n.pt # PyTorch
19
+ yolo11n.torchscript # TorchScript
20
+ yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
21
+ yolo11n_openvino_model # OpenVINO
22
+ yolo11n.engine # TensorRT
23
+ yolo11n.mlpackage # CoreML (macOS-only)
24
+ yolo11n_saved_model # TensorFlow SavedModel
25
+ yolo11n.pb # TensorFlow GraphDef
26
+ yolo11n.tflite # TensorFlow Lite
27
+ yolo11n_edgetpu.tflite # TensorFlow Edge TPU
28
+ yolo11n_paddle_model # PaddlePaddle
29
+ yolo11n.mnn # MNN
30
+ yolo11n_ncnn_model # NCNN
31
+ yolo11n_imx_model # Sony IMX
32
+ yolo11n_rknn_model # Rockchip RKNN
33
+ yolo11n.pte # PyTorch Executorch
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ import platform
39
+ import re
40
+ import threading
41
+ from pathlib import Path
42
+ from typing import Any
43
+
44
+ import cv2
45
+ import numpy as np
46
+ import torch
47
+
48
+ from ultralytics.cfg import get_cfg, get_save_dir
49
+ from ultralytics.data import load_inference_source
50
+ from ultralytics.data.augment import LetterBox
51
+ from ultralytics.nn.autobackend import AutoBackend
52
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
53
+ from ultralytics.utils.checks import check_imgsz, check_imshow
54
+ from ultralytics.utils.files import increment_path
55
+ from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode
56
+
57
+ STREAM_WARNING = """
58
+ Inference results will accumulate in RAM unless `stream=True` is passed, which can cause out-of-memory errors for large
59
+ sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.
60
+
61
+ Example:
62
+ results = model(source=..., stream=True) # generator of Results objects
63
+ for r in results:
64
+ boxes = r.boxes # Boxes object for bbox outputs
65
+ masks = r.masks # Masks object for segment masks outputs
66
+ probs = r.probs # Class probabilities for classification outputs
67
+ """
68
+
69
+
70
+ class BasePredictor:
71
+ """A base class for creating predictors.
72
+
73
+ This class provides the foundation for prediction functionality, handling model setup, inference, and result
74
+ processing across various input sources.
75
+
76
+ Attributes:
77
+ args (SimpleNamespace): Configuration for the predictor.
78
+ save_dir (Path): Directory to save results.
79
+ done_warmup (bool): Whether the predictor has finished setup.
80
+ model (torch.nn.Module): Model used for prediction.
81
+ data (dict): Data configuration.
82
+ device (torch.device): Device used for prediction.
83
+ dataset (Dataset): Dataset used for prediction.
84
+ vid_writer (dict[str, cv2.VideoWriter]): Dictionary of {save_path: video_writer} for saving video output.
85
+ plotted_img (np.ndarray): Last plotted image.
86
+ source_type (SimpleNamespace): Type of input source.
87
+ seen (int): Number of images processed.
88
+ windows (list[str]): List of window names for visualization.
89
+ batch (tuple): Current batch data.
90
+ results (list[Any]): Current batch results.
91
+ transforms (callable): Image transforms for classification.
92
+ callbacks (dict[str, list[callable]]): Callback functions for different events.
93
+ txt_path (Path): Path to save text results.
94
+ _lock (threading.Lock): Lock for thread-safe inference.
95
+
96
+ Methods:
97
+ preprocess: Prepare input image before inference.
98
+ inference: Run inference on a given image.
99
+ postprocess: Process raw predictions into structured results.
100
+ predict_cli: Run prediction for command line interface.
101
+ setup_source: Set up input source and inference mode.
102
+ stream_inference: Stream inference on input source.
103
+ setup_model: Initialize and configure the model.
104
+ write_results: Write inference results to files.
105
+ save_predicted_images: Save prediction visualizations.
106
+ show: Display results in a window.
107
+ run_callbacks: Execute registered callbacks for an event.
108
+ add_callback: Register a new callback function.
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ cfg=DEFAULT_CFG,
114
+ overrides: dict[str, Any] | None = None,
115
+ _callbacks: dict[str, list[callable]] | None = None,
116
+ ):
117
+ """Initialize the BasePredictor class.
118
+
119
+ Args:
120
+ cfg (str | dict): Path to a configuration file or a configuration dictionary.
121
+ overrides (dict, optional): Configuration overrides.
122
+ _callbacks (dict, optional): Dictionary of callback functions.
123
+ """
124
+ self.args = get_cfg(cfg, overrides)
125
+ self.save_dir = get_save_dir(self.args)
126
+ if self.args.conf is None:
127
+ self.args.conf = 0.25 # default conf=0.25
128
+ self.done_warmup = False
129
+ if self.args.show:
130
+ self.args.show = check_imshow(warn=True)
131
+
132
+ # Usable if setup is done
133
+ self.model = None
134
+ self.data = self.args.data # data_dict
135
+ self.imgsz = None
136
+ self.device = None
137
+ self.dataset = None
138
+ self.vid_writer = {} # dict of {save_path: video_writer, ...}
139
+ self.plotted_img = None
140
+ self.source_type = None
141
+ self.seen = 0
142
+ self.windows = []
143
+ self.batch = None
144
+ self.results = None
145
+ self.transforms = None
146
+ self.callbacks = _callbacks or callbacks.get_default_callbacks()
147
+ self.txt_path = None
148
+ self._lock = threading.Lock() # for automatic thread-safe inference
149
+ callbacks.add_integration_callbacks(self)
150
+
151
+ def preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor:
152
+ """Prepare input image before inference.
153
+
154
+ Args:
155
+ im (torch.Tensor | list[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.
156
+
157
+ Returns:
158
+ (torch.Tensor): Preprocessed image tensor of shape (N, 3, H, W).
159
+ """
160
+ not_tensor = not isinstance(im, torch.Tensor)
161
+ if not_tensor:
162
+ im = np.stack(self.pre_transform(im))
163
+ if im.shape[-1] == 3:
164
+ im = im[..., ::-1] # BGR to RGB
165
+ im = im.transpose((0, 3, 1, 2)) # BHWC to BCHW, (n, 3, h, w)
166
+ im = np.ascontiguousarray(im) # contiguous
167
+ im = torch.from_numpy(im)
168
+
169
+ im = im.to(self.device)
170
+ im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
171
+ if not_tensor:
172
+ im /= 255 # 0 - 255 to 0.0 - 1.0
173
+ return im
174
+
175
+ def inference(self, im: torch.Tensor, *args, **kwargs):
176
+ """Run inference on a given image using the specified model and arguments."""
177
+ visualize = (
178
+ increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
179
+ if self.args.visualize and (not self.source_type.tensor)
180
+ else False
181
+ )
182
+ return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
183
+
184
+ def pre_transform(self, im: list[np.ndarray]) -> list[np.ndarray]:
185
+ """Pre-transform input image before inference.
186
+
187
+ Args:
188
+ im (list[np.ndarray]): List of images with shape [(H, W, 3) x N].
189
+
190
+ Returns:
191
+ (list[np.ndarray]): List of transformed images.
192
+ """
193
+ same_shapes = len({x.shape for x in im}) == 1
194
+ letterbox = LetterBox(
195
+ self.imgsz,
196
+ auto=same_shapes
197
+ and self.args.rect
198
+ and (self.model.pt or (getattr(self.model, "dynamic", False) and not self.model.imx)),
199
+ stride=self.model.stride,
200
+ )
201
+ return [letterbox(image=x) for x in im]
202
+
203
+ def postprocess(self, preds, img, orig_imgs):
204
+ """Post-process predictions for an image and return them."""
205
+ return preds
206
+
207
+ def __call__(self, source=None, model=None, stream: bool = False, *args, **kwargs):
208
+ """Perform inference on an image or stream.
209
+
210
+ Args:
211
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
212
+ Source for inference.
213
+ model (str | Path | torch.nn.Module, optional): Model for inference.
214
+ stream (bool): Whether to stream the inference results. If True, returns a generator.
215
+ *args (Any): Additional arguments for the inference method.
216
+ **kwargs (Any): Additional keyword arguments for the inference method.
217
+
218
+ Returns:
219
+ (list[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
220
+ """
221
+ self.stream = stream
222
+ if stream:
223
+ return self.stream_inference(source, model, *args, **kwargs)
224
+ else:
225
+ return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Results into one
226
+
227
+ def predict_cli(self, source=None, model=None):
228
+ """Method used for Command Line Interface (CLI) prediction.
229
+
230
+ This function is designed to run predictions using the CLI. It sets up the source and model, then processes the
231
+ inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
232
+ generator without storing results.
233
+
234
+ Args:
235
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
236
+ Source for inference.
237
+ model (str | Path | torch.nn.Module, optional): Model for inference.
238
+
239
+ Notes:
240
+ Do not modify this function or remove the generator. The generator ensures that no outputs are
241
+ accumulated in memory, which is critical for preventing memory issues during long-running predictions.
242
+ """
243
+ gen = self.stream_inference(source, model)
244
+ for _ in gen: # sourcery skip: remove-empty-nested-block, noqa
245
+ pass
246
+
247
+ def setup_source(self, source, stride: int | None = None):
248
+ """Set up source and inference mode.
249
+
250
+ Args:
251
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor): Source for
252
+ inference.
253
+ stride (int, optional): Model stride for image size checking.
254
+ """
255
+ self.imgsz = check_imgsz(self.args.imgsz, stride=stride or self.model.stride, min_dim=2) # check image size
256
+ self.dataset = load_inference_source(
257
+ source=source,
258
+ batch=self.args.batch,
259
+ vid_stride=self.args.vid_stride,
260
+ buffer=self.args.stream_buffer,
261
+ channels=getattr(self.model, "ch", 3),
262
+ )
263
+ self.source_type = self.dataset.source_type
264
+ if (
265
+ self.source_type.stream
266
+ or self.source_type.screenshot
267
+ or len(self.dataset) > 1000 # many images
268
+ or any(getattr(self.dataset, "video_flag", [False]))
269
+ ): # long sequence
270
+ import torchvision # noqa (import here triggers torchvision NMS use in nms.py)
271
+
272
+ if not getattr(self, "stream", True): # videos
273
+ LOGGER.warning(STREAM_WARNING)
274
+ self.vid_writer = {}
275
+
276
+ @smart_inference_mode()
277
+ def stream_inference(self, source=None, model=None, *args, **kwargs):
278
+ """Stream real-time inference on camera feed and save results to file.
279
+
280
+ Args:
281
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
282
+ Source for inference.
283
+ model (str | Path | torch.nn.Module, optional): Model for inference.
284
+ *args (Any): Additional arguments for the inference method.
285
+ **kwargs (Any): Additional keyword arguments for the inference method.
286
+
287
+ Yields:
288
+ (ultralytics.engine.results.Results): Results objects.
289
+ """
290
+ if self.args.verbose:
291
+ LOGGER.info("")
292
+
293
+ # Setup model
294
+ if not self.model:
295
+ self.setup_model(model)
296
+
297
+ with self._lock: # for thread-safe inference
298
+ # Setup source every time predict is called
299
+ self.setup_source(source if source is not None else self.args.source)
300
+
301
+ # Check if save_dir/ label file exists
302
+ if self.args.save or self.args.save_txt:
303
+ (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
304
+
305
+ # Warmup model
306
+ if not self.done_warmup:
307
+ self.model.warmup(
308
+ imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, self.model.ch, *self.imgsz)
309
+ )
310
+ self.done_warmup = True
311
+
312
+ self.seen, self.windows, self.batch = 0, [], None
313
+ profilers = (
314
+ ops.Profile(device=self.device),
315
+ ops.Profile(device=self.device),
316
+ ops.Profile(device=self.device),
317
+ )
318
+ self.run_callbacks("on_predict_start")
319
+ for batch in self.dataset:
320
+ self.batch = batch
321
+ self.run_callbacks("on_predict_batch_start")
322
+ paths, im0s, s = self.batch
323
+
324
+ # Preprocess
325
+ with profilers[0]:
326
+ im = self.preprocess(im0s)
327
+
328
+ # Inference
329
+ with profilers[1]:
330
+ preds = self.inference(im, *args, **kwargs)
331
+ if self.args.embed:
332
+ yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors
333
+ continue
334
+
335
+ # Postprocess
336
+ with profilers[2]:
337
+ self.results = self.postprocess(preds, im, im0s)
338
+ self.run_callbacks("on_predict_postprocess_end")
339
+
340
+ # Visualize, save, write results
341
+ n = len(im0s)
342
+ try:
343
+ for i in range(n):
344
+ self.seen += 1
345
+ self.results[i].speed = {
346
+ "preprocess": profilers[0].dt * 1e3 / n,
347
+ "inference": profilers[1].dt * 1e3 / n,
348
+ "postprocess": profilers[2].dt * 1e3 / n,
349
+ }
350
+ if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
351
+ s[i] += self.write_results(i, Path(paths[i]), im, s)
352
+ except StopIteration:
353
+ break
354
+
355
+ # Print batch results
356
+ if self.args.verbose:
357
+ LOGGER.info("\n".join(s))
358
+
359
+ self.run_callbacks("on_predict_batch_end")
360
+ yield from self.results
361
+
362
+ # Release assets
363
+ for v in self.vid_writer.values():
364
+ if isinstance(v, cv2.VideoWriter):
365
+ v.release()
366
+
367
+ if self.args.show:
368
+ cv2.destroyAllWindows() # close any open windows
369
+
370
+ # Print final results
371
+ if self.args.verbose and self.seen:
372
+ t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image
373
+ LOGGER.info(
374
+ f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
375
+ f"{(min(self.args.batch, self.seen), getattr(self.model, 'ch', 3), *im.shape[2:])}" % t
376
+ )
377
+ if self.args.save or self.args.save_txt or self.args.save_crop:
378
+ nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels
379
+ s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
380
+ LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
381
+ self.run_callbacks("on_predict_end")
382
+
383
+ def setup_model(self, model, verbose: bool = True):
384
+ """Initialize YOLO model with given parameters and set it to evaluation mode.
385
+
386
+ Args:
387
+ model (str | Path | torch.nn.Module, optional): Model to load or use.
388
+ verbose (bool): Whether to print verbose output.
389
+ """
390
+ self.model = AutoBackend(
391
+ model=model or self.args.model,
392
+ device=select_device(self.args.device, verbose=verbose),
393
+ dnn=self.args.dnn,
394
+ data=self.args.data,
395
+ fp16=self.args.half,
396
+ fuse=True,
397
+ verbose=verbose,
398
+ )
399
+
400
+ self.device = self.model.device # update device
401
+ self.args.half = self.model.fp16 # update half
402
+ if hasattr(self.model, "imgsz") and not getattr(self.model, "dynamic", False):
403
+ self.args.imgsz = self.model.imgsz # reuse imgsz from export metadata
404
+ self.model.eval()
405
+ self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
406
+
407
+ def write_results(self, i: int, p: Path, im: torch.Tensor, s: list[str]) -> str:
408
+ """Write inference results to a file or directory.
409
+
410
+ Args:
411
+ i (int): Index of the current image in the batch.
412
+ p (Path): Path to the current image.
413
+ im (torch.Tensor): Preprocessed image tensor.
414
+ s (list[str]): List of result strings.
415
+
416
+ Returns:
417
+ (str): String with result information.
418
+ """
419
+ string = "" # print string
420
+ if len(im.shape) == 3:
421
+ im = im[None] # expand for batch dim
422
+ if self.source_type.stream or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
423
+ string += f"{i}: "
424
+ frame = self.dataset.count
425
+ else:
426
+ match = re.search(r"frame (\d+)/", s[i])
427
+ frame = int(match[1]) if match else None # 0 if frame undetermined
428
+
429
+ self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
430
+ string += "{:g}x{:g} ".format(*im.shape[2:])
431
+ result = self.results[i]
432
+ result.save_dir = self.save_dir.__str__() # used in other locations
433
+ string += f"{result.verbose()}{result.speed['inference']:.1f}ms"
434
+
435
+ # Add predictions to image
436
+ if self.args.save or self.args.show:
437
+ self.plotted_img = result.plot(
438
+ line_width=self.args.line_width,
439
+ boxes=self.args.show_boxes,
440
+ conf=self.args.show_conf,
441
+ labels=self.args.show_labels,
442
+ im_gpu=None if self.args.retina_masks else im[i],
443
+ )
444
+
445
+ # Save results
446
+ if self.args.save_txt:
447
+ result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
448
+ if self.args.save_crop:
449
+ result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
450
+ if self.args.show:
451
+ self.show(str(p))
452
+ if self.args.save:
453
+ self.save_predicted_images(self.save_dir / p.name, frame)
454
+
455
+ return string
456
+
457
+ def save_predicted_images(self, save_path: Path, frame: int = 0):
458
+ """Save video predictions as mp4 or images as jpg at specified path.
459
+
460
+ Args:
461
+ save_path (Path): Path to save the results.
462
+ frame (int): Frame number for video mode.
463
+ """
464
+ im = self.plotted_img
465
+
466
+ # Save videos and streams
467
+ if self.dataset.mode in {"stream", "video"}:
468
+ fps = self.dataset.fps if self.dataset.mode == "video" else 30
469
+ frames_path = self.save_dir / f"{save_path.stem}_frames" # save frames to a separate directory
470
+ if save_path not in self.vid_writer: # new video
471
+ if self.args.save_frames:
472
+ Path(frames_path).mkdir(parents=True, exist_ok=True)
473
+ suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
474
+ self.vid_writer[save_path] = cv2.VideoWriter(
475
+ filename=str(Path(save_path).with_suffix(suffix)),
476
+ fourcc=cv2.VideoWriter_fourcc(*fourcc),
477
+ fps=fps, # integer required, floats produce error in MP4 codec
478
+ frameSize=(im.shape[1], im.shape[0]), # (width, height)
479
+ )
480
+
481
+ # Save video
482
+ self.vid_writer[save_path].write(im)
483
+ if self.args.save_frames:
484
+ cv2.imwrite(f"{frames_path}/{save_path.stem}_{frame}.jpg", im)
485
+
486
+ # Save images
487
+ else:
488
+ cv2.imwrite(str(save_path.with_suffix(".jpg")), im) # save to JPG for best support
489
+
490
+ def show(self, p: str = ""):
491
+ """Display an image in a window."""
492
+ im = self.plotted_img
493
+ if platform.system() == "Linux" and p not in self.windows:
494
+ self.windows.append(p)
495
+ cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
496
+ cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height)
497
+ cv2.imshow(p, im)
498
+ if cv2.waitKey(300 if self.dataset.mode == "image" else 1) & 0xFF == ord("q"): # 300ms if image; else 1ms
499
+ raise StopIteration
500
+
501
+ def run_callbacks(self, event: str):
502
+ """Run all registered callbacks for a specific event."""
503
+ for callback in self.callbacks.get(event, []):
504
+ callback(self)
505
+
506
+ def add_callback(self, event: str, func: callable):
507
+ """Add a callback function for a specific event."""
508
+ self.callbacks[event].append(func)