dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -30,12 +30,16 @@ Usage - formats:
30
30
  yolo11n_ncnn_model # NCNN
31
31
  yolo11n_imx_model # Sony IMX
32
32
  yolo11n_rknn_model # Rockchip RKNN
33
+ yolo11n.pte # PyTorch Executorch
33
34
  """
34
35
 
36
+ from __future__ import annotations
37
+
35
38
  import platform
36
39
  import re
37
40
  import threading
38
41
  from pathlib import Path
42
+ from typing import Any
39
43
 
40
44
  import cv2
41
45
  import numpy as np
@@ -43,12 +47,12 @@ import torch
43
47
 
44
48
  from ultralytics.cfg import get_cfg, get_save_dir
45
49
  from ultralytics.data import load_inference_source
46
- from ultralytics.data.augment import LetterBox, classify_transforms
50
+ from ultralytics.data.augment import LetterBox
47
51
  from ultralytics.nn.autobackend import AutoBackend
48
52
  from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
49
53
  from ultralytics.utils.checks import check_imgsz, check_imshow
50
54
  from ultralytics.utils.files import increment_path
51
- from ultralytics.utils.torch_utils import select_device, smart_inference_mode
55
+ from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode
52
56
 
53
57
  STREAM_WARNING = """
54
58
  inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
@@ -64,11 +68,10 @@ Example:
64
68
 
65
69
 
66
70
  class BasePredictor:
67
- """
68
- A base class for creating predictors.
71
+ """A base class for creating predictors.
69
72
 
70
- This class provides the foundation for prediction functionality, handling model setup, inference,
71
- and result processing across various input sources.
73
+ This class provides the foundation for prediction functionality, handling model setup, inference, and result
74
+ processing across various input sources.
72
75
 
73
76
  Attributes:
74
77
  args (SimpleNamespace): Configuration for the predictor.
@@ -78,15 +81,15 @@ class BasePredictor:
78
81
  data (dict): Data configuration.
79
82
  device (torch.device): Device used for prediction.
80
83
  dataset (Dataset): Dataset used for prediction.
81
- vid_writer (dict): Dictionary of {save_path: video_writer} for saving video output.
82
- plotted_img (numpy.ndarray): Last plotted image.
84
+ vid_writer (dict[str, cv2.VideoWriter]): Dictionary of {save_path: video_writer} for saving video output.
85
+ plotted_img (np.ndarray): Last plotted image.
83
86
  source_type (SimpleNamespace): Type of input source.
84
87
  seen (int): Number of images processed.
85
- windows (list): List of window names for visualization.
88
+ windows (list[str]): List of window names for visualization.
86
89
  batch (tuple): Current batch data.
87
- results (list): Current batch results.
90
+ results (list[Any]): Current batch results.
88
91
  transforms (callable): Image transforms for classification.
89
- callbacks (dict): Callback functions for different events.
92
+ callbacks (dict[str, list[callable]]): Callback functions for different events.
90
93
  txt_path (Path): Path to save text results.
91
94
  _lock (threading.Lock): Lock for thread-safe inference.
92
95
 
@@ -105,14 +108,18 @@ class BasePredictor:
105
108
  add_callback: Register a new callback function.
106
109
  """
107
110
 
108
- def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
109
- """
110
- Initialize the BasePredictor class.
111
+ def __init__(
112
+ self,
113
+ cfg=DEFAULT_CFG,
114
+ overrides: dict[str, Any] | None = None,
115
+ _callbacks: dict[str, list[callable]] | None = None,
116
+ ):
117
+ """Initialize the BasePredictor class.
111
118
 
112
119
  Args:
113
120
  cfg (str | dict): Path to a configuration file or a configuration dictionary.
114
- overrides (dict | None): Configuration overrides.
115
- _callbacks (dict | None): Dictionary of callback functions.
121
+ overrides (dict, optional): Configuration overrides.
122
+ _callbacks (dict, optional): Dictionary of callback functions.
116
123
  """
117
124
  self.args = get_cfg(cfg, overrides)
118
125
  self.save_dir = get_save_dir(self.args)
@@ -141,12 +148,14 @@ class BasePredictor:
141
148
  self._lock = threading.Lock() # for automatic thread-safe inference
142
149
  callbacks.add_integration_callbacks(self)
143
150
 
144
- def preprocess(self, im):
145
- """
146
- Prepares input image before inference.
151
+ def preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor:
152
+ """Prepare input image before inference.
147
153
 
148
154
  Args:
149
- im (torch.Tensor | List(np.ndarray)): Images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
155
+ im (torch.Tensor | list[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.
156
+
157
+ Returns:
158
+ (torch.Tensor): Preprocessed image tensor of shape (N, 3, H, W).
150
159
  """
151
160
  not_tensor = not isinstance(im, torch.Tensor)
152
161
  if not_tensor:
@@ -163,7 +172,7 @@ class BasePredictor:
163
172
  im /= 255 # 0 - 255 to 0.0 - 1.0
164
173
  return im
165
174
 
166
- def inference(self, im, *args, **kwargs):
175
+ def inference(self, im: torch.Tensor, *args, **kwargs):
167
176
  """Run inference on a given image using the specified model and arguments."""
168
177
  visualize = (
169
178
  increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
@@ -172,15 +181,14 @@ class BasePredictor:
172
181
  )
173
182
  return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
174
183
 
175
- def pre_transform(self, im):
176
- """
177
- Pre-transform input image before inference.
184
+ def pre_transform(self, im: list[np.ndarray]) -> list[np.ndarray]:
185
+ """Pre-transform input image before inference.
178
186
 
179
187
  Args:
180
- im (List[np.ndarray]): Images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
188
+ im (list[np.ndarray]): List of images with shape [(H, W, 3) x N].
181
189
 
182
190
  Returns:
183
- (List[np.ndarray]): A list of transformed images.
191
+ (list[np.ndarray]): List of transformed images.
184
192
  """
185
193
  same_shapes = len({x.shape for x in im}) == 1
186
194
  letterbox = LetterBox(
@@ -196,20 +204,19 @@ class BasePredictor:
196
204
  """Post-process predictions for an image and return them."""
197
205
  return preds
198
206
 
199
- def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
200
- """
201
- Perform inference on an image or stream.
207
+ def __call__(self, source=None, model=None, stream: bool = False, *args, **kwargs):
208
+ """Perform inference on an image or stream.
202
209
 
203
210
  Args:
204
- source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
211
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
205
212
  Source for inference.
206
- model (str | Path | torch.nn.Module | None): Model for inference.
213
+ model (str | Path | torch.nn.Module, optional): Model for inference.
207
214
  stream (bool): Whether to stream the inference results. If True, returns a generator.
208
215
  *args (Any): Additional arguments for the inference method.
209
216
  **kwargs (Any): Additional keyword arguments for the inference method.
210
217
 
211
218
  Returns:
212
- (List[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
219
+ (list[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
213
220
  """
214
221
  self.stream = stream
215
222
  if stream:
@@ -218,19 +225,18 @@ class BasePredictor:
218
225
  return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
219
226
 
220
227
  def predict_cli(self, source=None, model=None):
221
- """
222
- Method used for Command Line Interface (CLI) prediction.
228
+ """Method used for Command Line Interface (CLI) prediction.
223
229
 
224
- This function is designed to run predictions using the CLI. It sets up the source and model, then processes
225
- the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
230
+ This function is designed to run predictions using the CLI. It sets up the source and model, then processes the
231
+ inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
226
232
  generator without storing results.
227
233
 
228
234
  Args:
229
- source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
235
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
230
236
  Source for inference.
231
- model (str | Path | torch.nn.Module | None): Model for inference.
237
+ model (str | Path | torch.nn.Module, optional): Model for inference.
232
238
 
233
- Note:
239
+ Notes:
234
240
  Do not modify this function or remove the generator. The generator ensures that no outputs are
235
241
  accumulated in memory, which is critical for preventing memory issues during long-running predictions.
236
242
  """
@@ -239,23 +245,13 @@ class BasePredictor:
239
245
  pass
240
246
 
241
247
  def setup_source(self, source):
242
- """
243
- Set up source and inference mode.
248
+ """Set up source and inference mode.
244
249
 
245
250
  Args:
246
- source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor):
247
- Source for inference.
251
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor): Source for
252
+ inference.
248
253
  """
249
254
  self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
250
- self.transforms = (
251
- getattr(
252
- self.model.model,
253
- "transforms",
254
- classify_transforms(self.imgsz[0]),
255
- )
256
- if self.args.task == "classify"
257
- else None
258
- )
259
255
  self.dataset = load_inference_source(
260
256
  source=source,
261
257
  batch=self.args.batch,
@@ -264,24 +260,27 @@ class BasePredictor:
264
260
  channels=getattr(self.model, "ch", 3),
265
261
  )
266
262
  self.source_type = self.dataset.source_type
267
- if not getattr(self, "stream", True) and (
263
+ long_sequence = (
268
264
  self.source_type.stream
269
265
  or self.source_type.screenshot
270
266
  or len(self.dataset) > 1000 # many images
271
267
  or any(getattr(self.dataset, "video_flag", [False]))
272
- ): # videos
273
- LOGGER.warning(STREAM_WARNING)
268
+ )
269
+ if long_sequence:
270
+ import torchvision # noqa (import here triggers torchvision NMS use in nms.py)
271
+
272
+ if not getattr(self, "stream", True): # videos
273
+ LOGGER.warning(STREAM_WARNING)
274
274
  self.vid_writer = {}
275
275
 
276
276
  @smart_inference_mode()
277
277
  def stream_inference(self, source=None, model=None, *args, **kwargs):
278
- """
279
- Stream real-time inference on camera feed and save results to file.
278
+ """Stream real-time inference on camera feed and save results to file.
280
279
 
281
280
  Args:
282
- source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
281
+ source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
283
282
  Source for inference.
284
- model (str | Path | torch.nn.Module | None): Model for inference.
283
+ model (str | Path | torch.nn.Module, optional): Model for inference.
285
284
  *args (Any): Additional arguments for the inference method.
286
285
  **kwargs (Any): Additional keyword arguments for the inference method.
287
286
 
@@ -339,15 +338,18 @@ class BasePredictor:
339
338
 
340
339
  # Visualize, save, write results
341
340
  n = len(im0s)
342
- for i in range(n):
343
- self.seen += 1
344
- self.results[i].speed = {
345
- "preprocess": profilers[0].dt * 1e3 / n,
346
- "inference": profilers[1].dt * 1e3 / n,
347
- "postprocess": profilers[2].dt * 1e3 / n,
348
- }
349
- if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
350
- s[i] += self.write_results(i, Path(paths[i]), im, s)
341
+ try:
342
+ for i in range(n):
343
+ self.seen += 1
344
+ self.results[i].speed = {
345
+ "preprocess": profilers[0].dt * 1e3 / n,
346
+ "inference": profilers[1].dt * 1e3 / n,
347
+ "postprocess": profilers[2].dt * 1e3 / n,
348
+ }
349
+ if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
350
+ s[i] += self.write_results(i, Path(paths[i]), im, s)
351
+ except StopIteration:
352
+ break
351
353
 
352
354
  # Print batch results
353
355
  if self.args.verbose:
@@ -361,6 +363,9 @@ class BasePredictor:
361
363
  if isinstance(v, cv2.VideoWriter):
362
364
  v.release()
363
365
 
366
+ if self.args.show:
367
+ cv2.destroyAllWindows() # close any open windows
368
+
364
369
  # Print final results
365
370
  if self.args.verbose and self.seen:
366
371
  t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image
@@ -374,38 +379,38 @@ class BasePredictor:
374
379
  LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
375
380
  self.run_callbacks("on_predict_end")
376
381
 
377
- def setup_model(self, model, verbose=True):
378
- """
379
- Initialize YOLO model with given parameters and set it to evaluation mode.
382
+ def setup_model(self, model, verbose: bool = True):
383
+ """Initialize YOLO model with given parameters and set it to evaluation mode.
380
384
 
381
385
  Args:
382
- model (str | Path | torch.nn.Module | None): Model to load or use.
386
+ model (str | Path | torch.nn.Module, optional): Model to load or use.
383
387
  verbose (bool): Whether to print verbose output.
384
388
  """
385
389
  self.model = AutoBackend(
386
- weights=model or self.args.model,
390
+ model=model or self.args.model,
387
391
  device=select_device(self.args.device, verbose=verbose),
388
392
  dnn=self.args.dnn,
389
393
  data=self.args.data,
390
394
  fp16=self.args.half,
391
- batch=self.args.batch,
392
395
  fuse=True,
393
396
  verbose=verbose,
394
397
  )
395
398
 
396
399
  self.device = self.model.device # update device
397
400
  self.args.half = self.model.fp16 # update half
401
+ if hasattr(self.model, "imgsz") and not getattr(self.model, "dynamic", False):
402
+ self.args.imgsz = self.model.imgsz # reuse imgsz from export metadata
398
403
  self.model.eval()
404
+ self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
399
405
 
400
- def write_results(self, i, p, im, s):
401
- """
402
- Write inference results to a file or directory.
406
+ def write_results(self, i: int, p: Path, im: torch.Tensor, s: list[str]) -> str:
407
+ """Write inference results to a file or directory.
403
408
 
404
409
  Args:
405
410
  i (int): Index of the current image in the batch.
406
411
  p (Path): Path to the current image.
407
412
  im (torch.Tensor): Preprocessed image tensor.
408
- s (List[str]): List of result strings.
413
+ s (list[str]): List of result strings.
409
414
 
410
415
  Returns:
411
416
  (str): String with result information.
@@ -444,16 +449,15 @@ class BasePredictor:
444
449
  if self.args.show:
445
450
  self.show(str(p))
446
451
  if self.args.save:
447
- self.save_predicted_images(str(self.save_dir / p.name), frame)
452
+ self.save_predicted_images(self.save_dir / p.name, frame)
448
453
 
449
454
  return string
450
455
 
451
- def save_predicted_images(self, save_path="", frame=0):
452
- """
453
- Save video predictions as mp4 or images as jpg at specified path.
456
+ def save_predicted_images(self, save_path: Path, frame: int = 0):
457
+ """Save video predictions as mp4 or images as jpg at specified path.
454
458
 
455
459
  Args:
456
- save_path (str): Path to save the results.
460
+ save_path (Path): Path to save the results.
457
461
  frame (int): Frame number for video mode.
458
462
  """
459
463
  im = self.plotted_img
@@ -461,7 +465,7 @@ class BasePredictor:
461
465
  # Save videos and streams
462
466
  if self.dataset.mode in {"stream", "video"}:
463
467
  fps = self.dataset.fps if self.dataset.mode == "video" else 30
464
- frames_path = f"{save_path.split('.', 1)[0]}_frames/"
468
+ frames_path = self.save_dir / f"{save_path.stem}_frames" # save frames to a separate directory
465
469
  if save_path not in self.vid_writer: # new video
466
470
  if self.args.save_frames:
467
471
  Path(frames_path).mkdir(parents=True, exist_ok=True)
@@ -476,13 +480,13 @@ class BasePredictor:
476
480
  # Save video
477
481
  self.vid_writer[save_path].write(im)
478
482
  if self.args.save_frames:
479
- cv2.imwrite(f"{frames_path}{frame}.jpg", im)
483
+ cv2.imwrite(f"{frames_path}/{save_path.stem}_{frame}.jpg", im)
480
484
 
481
485
  # Save images
482
486
  else:
483
- cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im) # save to JPG for best support
487
+ cv2.imwrite(str(save_path.with_suffix(".jpg")), im) # save to JPG for best support
484
488
 
485
- def show(self, p=""):
489
+ def show(self, p: str = ""):
486
490
  """Display an image in a window."""
487
491
  im = self.plotted_img
488
492
  if platform.system() == "Linux" and p not in self.windows:
@@ -490,13 +494,14 @@ class BasePredictor:
490
494
  cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
491
495
  cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height)
492
496
  cv2.imshow(p, im)
493
- cv2.waitKey(300 if self.dataset.mode == "image" else 1) # 1 millisecond
497
+ if cv2.waitKey(300 if self.dataset.mode == "image" else 1) & 0xFF == ord("q"): # 300ms if image; else 1ms
498
+ raise StopIteration
494
499
 
495
500
  def run_callbacks(self, event: str):
496
501
  """Run all registered callbacks for a specific event."""
497
502
  for callback in self.callbacks.get(event, []):
498
503
  callback(self)
499
504
 
500
- def add_callback(self, event: str, func):
505
+ def add_callback(self, event: str, func: callable):
501
506
  """Add a callback function for a specific event."""
502
507
  self.callbacks[event].append(func)