dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/engine/predictor.py
CHANGED
|
@@ -30,12 +30,16 @@ Usage - formats:
|
|
|
30
30
|
yolo11n_ncnn_model # NCNN
|
|
31
31
|
yolo11n_imx_model # Sony IMX
|
|
32
32
|
yolo11n_rknn_model # Rockchip RKNN
|
|
33
|
+
yolo11n.pte # PyTorch Executorch
|
|
33
34
|
"""
|
|
34
35
|
|
|
36
|
+
from __future__ import annotations
|
|
37
|
+
|
|
35
38
|
import platform
|
|
36
39
|
import re
|
|
37
40
|
import threading
|
|
38
41
|
from pathlib import Path
|
|
42
|
+
from typing import Any
|
|
39
43
|
|
|
40
44
|
import cv2
|
|
41
45
|
import numpy as np
|
|
@@ -43,12 +47,12 @@ import torch
|
|
|
43
47
|
|
|
44
48
|
from ultralytics.cfg import get_cfg, get_save_dir
|
|
45
49
|
from ultralytics.data import load_inference_source
|
|
46
|
-
from ultralytics.data.augment import LetterBox
|
|
50
|
+
from ultralytics.data.augment import LetterBox
|
|
47
51
|
from ultralytics.nn.autobackend import AutoBackend
|
|
48
52
|
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
|
|
49
53
|
from ultralytics.utils.checks import check_imgsz, check_imshow
|
|
50
54
|
from ultralytics.utils.files import increment_path
|
|
51
|
-
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
|
55
|
+
from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode
|
|
52
56
|
|
|
53
57
|
STREAM_WARNING = """
|
|
54
58
|
inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
|
|
@@ -64,11 +68,10 @@ Example:
|
|
|
64
68
|
|
|
65
69
|
|
|
66
70
|
class BasePredictor:
|
|
67
|
-
"""
|
|
68
|
-
A base class for creating predictors.
|
|
71
|
+
"""A base class for creating predictors.
|
|
69
72
|
|
|
70
|
-
This class provides the foundation for prediction functionality, handling model setup, inference,
|
|
71
|
-
|
|
73
|
+
This class provides the foundation for prediction functionality, handling model setup, inference, and result
|
|
74
|
+
processing across various input sources.
|
|
72
75
|
|
|
73
76
|
Attributes:
|
|
74
77
|
args (SimpleNamespace): Configuration for the predictor.
|
|
@@ -78,15 +81,15 @@ class BasePredictor:
|
|
|
78
81
|
data (dict): Data configuration.
|
|
79
82
|
device (torch.device): Device used for prediction.
|
|
80
83
|
dataset (Dataset): Dataset used for prediction.
|
|
81
|
-
vid_writer (dict): Dictionary of {save_path: video_writer} for saving video output.
|
|
82
|
-
plotted_img (
|
|
84
|
+
vid_writer (dict[str, cv2.VideoWriter]): Dictionary of {save_path: video_writer} for saving video output.
|
|
85
|
+
plotted_img (np.ndarray): Last plotted image.
|
|
83
86
|
source_type (SimpleNamespace): Type of input source.
|
|
84
87
|
seen (int): Number of images processed.
|
|
85
|
-
windows (list): List of window names for visualization.
|
|
88
|
+
windows (list[str]): List of window names for visualization.
|
|
86
89
|
batch (tuple): Current batch data.
|
|
87
|
-
results (list): Current batch results.
|
|
90
|
+
results (list[Any]): Current batch results.
|
|
88
91
|
transforms (callable): Image transforms for classification.
|
|
89
|
-
callbacks (dict): Callback functions for different events.
|
|
92
|
+
callbacks (dict[str, list[callable]]): Callback functions for different events.
|
|
90
93
|
txt_path (Path): Path to save text results.
|
|
91
94
|
_lock (threading.Lock): Lock for thread-safe inference.
|
|
92
95
|
|
|
@@ -105,14 +108,18 @@ class BasePredictor:
|
|
|
105
108
|
add_callback: Register a new callback function.
|
|
106
109
|
"""
|
|
107
110
|
|
|
108
|
-
def __init__(
|
|
109
|
-
|
|
110
|
-
|
|
111
|
+
def __init__(
|
|
112
|
+
self,
|
|
113
|
+
cfg=DEFAULT_CFG,
|
|
114
|
+
overrides: dict[str, Any] | None = None,
|
|
115
|
+
_callbacks: dict[str, list[callable]] | None = None,
|
|
116
|
+
):
|
|
117
|
+
"""Initialize the BasePredictor class.
|
|
111
118
|
|
|
112
119
|
Args:
|
|
113
120
|
cfg (str | dict): Path to a configuration file or a configuration dictionary.
|
|
114
|
-
overrides (dict
|
|
115
|
-
_callbacks (dict
|
|
121
|
+
overrides (dict, optional): Configuration overrides.
|
|
122
|
+
_callbacks (dict, optional): Dictionary of callback functions.
|
|
116
123
|
"""
|
|
117
124
|
self.args = get_cfg(cfg, overrides)
|
|
118
125
|
self.save_dir = get_save_dir(self.args)
|
|
@@ -141,12 +148,14 @@ class BasePredictor:
|
|
|
141
148
|
self._lock = threading.Lock() # for automatic thread-safe inference
|
|
142
149
|
callbacks.add_integration_callbacks(self)
|
|
143
150
|
|
|
144
|
-
def preprocess(self, im):
|
|
145
|
-
"""
|
|
146
|
-
Prepares input image before inference.
|
|
151
|
+
def preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor:
|
|
152
|
+
"""Prepare input image before inference.
|
|
147
153
|
|
|
148
154
|
Args:
|
|
149
|
-
im (torch.Tensor |
|
|
155
|
+
im (torch.Tensor | list[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.
|
|
156
|
+
|
|
157
|
+
Returns:
|
|
158
|
+
(torch.Tensor): Preprocessed image tensor of shape (N, 3, H, W).
|
|
150
159
|
"""
|
|
151
160
|
not_tensor = not isinstance(im, torch.Tensor)
|
|
152
161
|
if not_tensor:
|
|
@@ -163,7 +172,7 @@ class BasePredictor:
|
|
|
163
172
|
im /= 255 # 0 - 255 to 0.0 - 1.0
|
|
164
173
|
return im
|
|
165
174
|
|
|
166
|
-
def inference(self, im, *args, **kwargs):
|
|
175
|
+
def inference(self, im: torch.Tensor, *args, **kwargs):
|
|
167
176
|
"""Run inference on a given image using the specified model and arguments."""
|
|
168
177
|
visualize = (
|
|
169
178
|
increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
|
|
@@ -172,15 +181,14 @@ class BasePredictor:
|
|
|
172
181
|
)
|
|
173
182
|
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
|
|
174
183
|
|
|
175
|
-
def pre_transform(self, im):
|
|
176
|
-
"""
|
|
177
|
-
Pre-transform input image before inference.
|
|
184
|
+
def pre_transform(self, im: list[np.ndarray]) -> list[np.ndarray]:
|
|
185
|
+
"""Pre-transform input image before inference.
|
|
178
186
|
|
|
179
187
|
Args:
|
|
180
|
-
im (
|
|
188
|
+
im (list[np.ndarray]): List of images with shape [(H, W, 3) x N].
|
|
181
189
|
|
|
182
190
|
Returns:
|
|
183
|
-
(
|
|
191
|
+
(list[np.ndarray]): List of transformed images.
|
|
184
192
|
"""
|
|
185
193
|
same_shapes = len({x.shape for x in im}) == 1
|
|
186
194
|
letterbox = LetterBox(
|
|
@@ -196,20 +204,19 @@ class BasePredictor:
|
|
|
196
204
|
"""Post-process predictions for an image and return them."""
|
|
197
205
|
return preds
|
|
198
206
|
|
|
199
|
-
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
|
|
200
|
-
"""
|
|
201
|
-
Perform inference on an image or stream.
|
|
207
|
+
def __call__(self, source=None, model=None, stream: bool = False, *args, **kwargs):
|
|
208
|
+
"""Perform inference on an image or stream.
|
|
202
209
|
|
|
203
210
|
Args:
|
|
204
|
-
source (str | Path |
|
|
211
|
+
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
|
|
205
212
|
Source for inference.
|
|
206
|
-
model (str | Path | torch.nn.Module
|
|
213
|
+
model (str | Path | torch.nn.Module, optional): Model for inference.
|
|
207
214
|
stream (bool): Whether to stream the inference results. If True, returns a generator.
|
|
208
215
|
*args (Any): Additional arguments for the inference method.
|
|
209
216
|
**kwargs (Any): Additional keyword arguments for the inference method.
|
|
210
217
|
|
|
211
218
|
Returns:
|
|
212
|
-
(
|
|
219
|
+
(list[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
|
|
213
220
|
"""
|
|
214
221
|
self.stream = stream
|
|
215
222
|
if stream:
|
|
@@ -218,19 +225,18 @@ class BasePredictor:
|
|
|
218
225
|
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
|
|
219
226
|
|
|
220
227
|
def predict_cli(self, source=None, model=None):
|
|
221
|
-
"""
|
|
222
|
-
Method used for Command Line Interface (CLI) prediction.
|
|
228
|
+
"""Method used for Command Line Interface (CLI) prediction.
|
|
223
229
|
|
|
224
|
-
This function is designed to run predictions using the CLI. It sets up the source and model, then processes
|
|
225
|
-
|
|
230
|
+
This function is designed to run predictions using the CLI. It sets up the source and model, then processes the
|
|
231
|
+
inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
|
|
226
232
|
generator without storing results.
|
|
227
233
|
|
|
228
234
|
Args:
|
|
229
|
-
source (str | Path |
|
|
235
|
+
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
|
|
230
236
|
Source for inference.
|
|
231
|
-
model (str | Path | torch.nn.Module
|
|
237
|
+
model (str | Path | torch.nn.Module, optional): Model for inference.
|
|
232
238
|
|
|
233
|
-
|
|
239
|
+
Notes:
|
|
234
240
|
Do not modify this function or remove the generator. The generator ensures that no outputs are
|
|
235
241
|
accumulated in memory, which is critical for preventing memory issues during long-running predictions.
|
|
236
242
|
"""
|
|
@@ -239,23 +245,13 @@ class BasePredictor:
|
|
|
239
245
|
pass
|
|
240
246
|
|
|
241
247
|
def setup_source(self, source):
|
|
242
|
-
"""
|
|
243
|
-
Set up source and inference mode.
|
|
248
|
+
"""Set up source and inference mode.
|
|
244
249
|
|
|
245
250
|
Args:
|
|
246
|
-
source (str | Path |
|
|
247
|
-
|
|
251
|
+
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor): Source for
|
|
252
|
+
inference.
|
|
248
253
|
"""
|
|
249
254
|
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
|
250
|
-
self.transforms = (
|
|
251
|
-
getattr(
|
|
252
|
-
self.model.model,
|
|
253
|
-
"transforms",
|
|
254
|
-
classify_transforms(self.imgsz[0]),
|
|
255
|
-
)
|
|
256
|
-
if self.args.task == "classify"
|
|
257
|
-
else None
|
|
258
|
-
)
|
|
259
255
|
self.dataset = load_inference_source(
|
|
260
256
|
source=source,
|
|
261
257
|
batch=self.args.batch,
|
|
@@ -264,24 +260,27 @@ class BasePredictor:
|
|
|
264
260
|
channels=getattr(self.model, "ch", 3),
|
|
265
261
|
)
|
|
266
262
|
self.source_type = self.dataset.source_type
|
|
267
|
-
|
|
263
|
+
long_sequence = (
|
|
268
264
|
self.source_type.stream
|
|
269
265
|
or self.source_type.screenshot
|
|
270
266
|
or len(self.dataset) > 1000 # many images
|
|
271
267
|
or any(getattr(self.dataset, "video_flag", [False]))
|
|
272
|
-
)
|
|
273
|
-
|
|
268
|
+
)
|
|
269
|
+
if long_sequence:
|
|
270
|
+
import torchvision # noqa (import here triggers torchvision NMS use in nms.py)
|
|
271
|
+
|
|
272
|
+
if not getattr(self, "stream", True): # videos
|
|
273
|
+
LOGGER.warning(STREAM_WARNING)
|
|
274
274
|
self.vid_writer = {}
|
|
275
275
|
|
|
276
276
|
@smart_inference_mode()
|
|
277
277
|
def stream_inference(self, source=None, model=None, *args, **kwargs):
|
|
278
|
-
"""
|
|
279
|
-
Stream real-time inference on camera feed and save results to file.
|
|
278
|
+
"""Stream real-time inference on camera feed and save results to file.
|
|
280
279
|
|
|
281
280
|
Args:
|
|
282
|
-
source (str | Path |
|
|
281
|
+
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
|
|
283
282
|
Source for inference.
|
|
284
|
-
model (str | Path | torch.nn.Module
|
|
283
|
+
model (str | Path | torch.nn.Module, optional): Model for inference.
|
|
285
284
|
*args (Any): Additional arguments for the inference method.
|
|
286
285
|
**kwargs (Any): Additional keyword arguments for the inference method.
|
|
287
286
|
|
|
@@ -339,15 +338,18 @@ class BasePredictor:
|
|
|
339
338
|
|
|
340
339
|
# Visualize, save, write results
|
|
341
340
|
n = len(im0s)
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
341
|
+
try:
|
|
342
|
+
for i in range(n):
|
|
343
|
+
self.seen += 1
|
|
344
|
+
self.results[i].speed = {
|
|
345
|
+
"preprocess": profilers[0].dt * 1e3 / n,
|
|
346
|
+
"inference": profilers[1].dt * 1e3 / n,
|
|
347
|
+
"postprocess": profilers[2].dt * 1e3 / n,
|
|
348
|
+
}
|
|
349
|
+
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
|
350
|
+
s[i] += self.write_results(i, Path(paths[i]), im, s)
|
|
351
|
+
except StopIteration:
|
|
352
|
+
break
|
|
351
353
|
|
|
352
354
|
# Print batch results
|
|
353
355
|
if self.args.verbose:
|
|
@@ -361,6 +363,9 @@ class BasePredictor:
|
|
|
361
363
|
if isinstance(v, cv2.VideoWriter):
|
|
362
364
|
v.release()
|
|
363
365
|
|
|
366
|
+
if self.args.show:
|
|
367
|
+
cv2.destroyAllWindows() # close any open windows
|
|
368
|
+
|
|
364
369
|
# Print final results
|
|
365
370
|
if self.args.verbose and self.seen:
|
|
366
371
|
t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image
|
|
@@ -374,38 +379,38 @@ class BasePredictor:
|
|
|
374
379
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
|
375
380
|
self.run_callbacks("on_predict_end")
|
|
376
381
|
|
|
377
|
-
def setup_model(self, model, verbose=True):
|
|
378
|
-
"""
|
|
379
|
-
Initialize YOLO model with given parameters and set it to evaluation mode.
|
|
382
|
+
def setup_model(self, model, verbose: bool = True):
|
|
383
|
+
"""Initialize YOLO model with given parameters and set it to evaluation mode.
|
|
380
384
|
|
|
381
385
|
Args:
|
|
382
|
-
model (str | Path | torch.nn.Module
|
|
386
|
+
model (str | Path | torch.nn.Module, optional): Model to load or use.
|
|
383
387
|
verbose (bool): Whether to print verbose output.
|
|
384
388
|
"""
|
|
385
389
|
self.model = AutoBackend(
|
|
386
|
-
|
|
390
|
+
model=model or self.args.model,
|
|
387
391
|
device=select_device(self.args.device, verbose=verbose),
|
|
388
392
|
dnn=self.args.dnn,
|
|
389
393
|
data=self.args.data,
|
|
390
394
|
fp16=self.args.half,
|
|
391
|
-
batch=self.args.batch,
|
|
392
395
|
fuse=True,
|
|
393
396
|
verbose=verbose,
|
|
394
397
|
)
|
|
395
398
|
|
|
396
399
|
self.device = self.model.device # update device
|
|
397
400
|
self.args.half = self.model.fp16 # update half
|
|
401
|
+
if hasattr(self.model, "imgsz") and not getattr(self.model, "dynamic", False):
|
|
402
|
+
self.args.imgsz = self.model.imgsz # reuse imgsz from export metadata
|
|
398
403
|
self.model.eval()
|
|
404
|
+
self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
|
|
399
405
|
|
|
400
|
-
def write_results(self, i, p, im, s):
|
|
401
|
-
"""
|
|
402
|
-
Write inference results to a file or directory.
|
|
406
|
+
def write_results(self, i: int, p: Path, im: torch.Tensor, s: list[str]) -> str:
|
|
407
|
+
"""Write inference results to a file or directory.
|
|
403
408
|
|
|
404
409
|
Args:
|
|
405
410
|
i (int): Index of the current image in the batch.
|
|
406
411
|
p (Path): Path to the current image.
|
|
407
412
|
im (torch.Tensor): Preprocessed image tensor.
|
|
408
|
-
s (
|
|
413
|
+
s (list[str]): List of result strings.
|
|
409
414
|
|
|
410
415
|
Returns:
|
|
411
416
|
(str): String with result information.
|
|
@@ -444,16 +449,15 @@ class BasePredictor:
|
|
|
444
449
|
if self.args.show:
|
|
445
450
|
self.show(str(p))
|
|
446
451
|
if self.args.save:
|
|
447
|
-
self.save_predicted_images(
|
|
452
|
+
self.save_predicted_images(self.save_dir / p.name, frame)
|
|
448
453
|
|
|
449
454
|
return string
|
|
450
455
|
|
|
451
|
-
def save_predicted_images(self, save_path
|
|
452
|
-
"""
|
|
453
|
-
Save video predictions as mp4 or images as jpg at specified path.
|
|
456
|
+
def save_predicted_images(self, save_path: Path, frame: int = 0):
|
|
457
|
+
"""Save video predictions as mp4 or images as jpg at specified path.
|
|
454
458
|
|
|
455
459
|
Args:
|
|
456
|
-
save_path (
|
|
460
|
+
save_path (Path): Path to save the results.
|
|
457
461
|
frame (int): Frame number for video mode.
|
|
458
462
|
"""
|
|
459
463
|
im = self.plotted_img
|
|
@@ -461,7 +465,7 @@ class BasePredictor:
|
|
|
461
465
|
# Save videos and streams
|
|
462
466
|
if self.dataset.mode in {"stream", "video"}:
|
|
463
467
|
fps = self.dataset.fps if self.dataset.mode == "video" else 30
|
|
464
|
-
frames_path = f"{save_path.
|
|
468
|
+
frames_path = self.save_dir / f"{save_path.stem}_frames" # save frames to a separate directory
|
|
465
469
|
if save_path not in self.vid_writer: # new video
|
|
466
470
|
if self.args.save_frames:
|
|
467
471
|
Path(frames_path).mkdir(parents=True, exist_ok=True)
|
|
@@ -476,13 +480,13 @@ class BasePredictor:
|
|
|
476
480
|
# Save video
|
|
477
481
|
self.vid_writer[save_path].write(im)
|
|
478
482
|
if self.args.save_frames:
|
|
479
|
-
cv2.imwrite(f"{frames_path}{frame}.jpg", im)
|
|
483
|
+
cv2.imwrite(f"{frames_path}/{save_path.stem}_{frame}.jpg", im)
|
|
480
484
|
|
|
481
485
|
# Save images
|
|
482
486
|
else:
|
|
483
|
-
cv2.imwrite(str(
|
|
487
|
+
cv2.imwrite(str(save_path.with_suffix(".jpg")), im) # save to JPG for best support
|
|
484
488
|
|
|
485
|
-
def show(self, p=""):
|
|
489
|
+
def show(self, p: str = ""):
|
|
486
490
|
"""Display an image in a window."""
|
|
487
491
|
im = self.plotted_img
|
|
488
492
|
if platform.system() == "Linux" and p not in self.windows:
|
|
@@ -490,13 +494,14 @@ class BasePredictor:
|
|
|
490
494
|
cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
|
|
491
495
|
cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height)
|
|
492
496
|
cv2.imshow(p, im)
|
|
493
|
-
cv2.waitKey(300 if self.dataset.mode == "image" else 1) #
|
|
497
|
+
if cv2.waitKey(300 if self.dataset.mode == "image" else 1) & 0xFF == ord("q"): # 300ms if image; else 1ms
|
|
498
|
+
raise StopIteration
|
|
494
499
|
|
|
495
500
|
def run_callbacks(self, event: str):
|
|
496
501
|
"""Run all registered callbacks for a specific event."""
|
|
497
502
|
for callback in self.callbacks.get(event, []):
|
|
498
503
|
callback(self)
|
|
499
504
|
|
|
500
|
-
def add_callback(self, event: str, func):
|
|
505
|
+
def add_callback(self, event: str, func: callable):
|
|
501
506
|
"""Add a callback function for a specific event."""
|
|
502
507
|
self.callbacks[event].append(func)
|