dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,502 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.

Usage - sources:
    $ yolo mode=predict model=yolo11n.pt source=0                               # webcam
                                                img.jpg                         # image
                                                vid.mp4                         # video
                                                screen                          # screenshot
                                                path/                           # directory
                                                list.txt                        # list of images
                                                list.streams                    # list of streams
                                                'path/*.jpg'                    # glob
                                                'https://youtu.be/LNwODJXcvt4'  # YouTube
                                                'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP, TCP stream

Usage - formats:
    $ yolo mode=predict model=yolo11n.pt                 # PyTorch
                              yolo11n.torchscript        # TorchScript
                              yolo11n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
                              yolo11n_openvino_model     # OpenVINO
                              yolo11n.engine             # TensorRT
                              yolo11n.mlpackage          # CoreML (macOS-only)
                              yolo11n_saved_model        # TensorFlow SavedModel
                              yolo11n.pb                 # TensorFlow GraphDef
                              yolo11n.tflite             # TensorFlow Lite
                              yolo11n_edgetpu.tflite     # TensorFlow Edge TPU
                              yolo11n_paddle_model       # PaddlePaddle
                              yolo11n.mnn                # MNN
                              yolo11n_ncnn_model         # NCNN
                              yolo11n_imx_model          # Sony IMX
                              yolo11n_rknn_model         # Rockchip RKNN
"""

import platform
import re
import threading
from pathlib import Path

import cv2
import numpy as np
import torch

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data import load_inference_source
from ultralytics.data.augment import LetterBox, classify_transforms
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
from ultralytics.utils.checks import check_imgsz, check_imshow
from ultralytics.utils.files import increment_path
from ultralytics.utils.torch_utils import select_device, smart_inference_mode

STREAM_WARNING = """
inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.

Example:
    results = model(source=..., stream=True)  # generator of Results objects
    for r in results:
        boxes = r.boxes  # Boxes object for bbox outputs
        masks = r.masks  # Masks object for segment masks outputs
        probs = r.probs  # Class probabilities for classification outputs
"""


class BasePredictor:
    """
    A base class for creating predictors.

    This class provides the foundation for prediction functionality, handling model setup, inference,
    and result processing across various input sources.

    Attributes:
        args (SimpleNamespace): Configuration for the predictor.
        save_dir (Path): Directory to save results.
        done_warmup (bool): Whether the predictor has finished setup.
        model (torch.nn.Module): Model used for prediction.
        data (dict): Data configuration.
        device (torch.device): Device used for prediction.
        dataset (Dataset): Dataset used for prediction.
        vid_writer (dict): Dictionary of {save_path: video_writer} for saving video output.
        plotted_img (numpy.ndarray): Last plotted image.
        source_type (SimpleNamespace): Type of input source.
        seen (int): Number of images processed.
        windows (list): List of window names for visualization.
        batch (tuple): Current batch data.
        results (list): Current batch results.
        transforms (callable): Image transforms for classification.
        callbacks (dict): Callback functions for different events.
        txt_path (Path): Path to save text results.
        _lock (threading.Lock): Lock for thread-safe inference.

    Methods:
        preprocess: Prepare input image before inference.
        inference: Run inference on a given image.
        postprocess: Process raw predictions into structured results.
        predict_cli: Run prediction for command line interface.
        setup_source: Set up input source and inference mode.
        stream_inference: Stream inference on input source.
        setup_model: Initialize and configure the model.
        write_results: Write inference results to files.
        save_predicted_images: Save prediction visualizations.
        show: Display results in a window.
        run_callbacks: Execute registered callbacks for an event.
        add_callback: Register a new callback function.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the BasePredictor class.

        Args:
            cfg (str | dict): Path to a configuration file or a configuration dictionary.
            overrides (dict | None): Configuration overrides.
            _callbacks (dict | None): Dictionary of callback functions.
        """
        self.args = get_cfg(cfg, overrides)
        self.save_dir = get_save_dir(self.args)
        if self.args.conf is None:
            self.args.conf = 0.25  # default conf=0.25
        self.done_warmup = False
        if self.args.show:
            self.args.show = check_imshow(warn=True)

        # Usable if setup is done
        self.model = None
        self.data = self.args.data  # data_dict
        self.imgsz = None
        self.device = None
        self.dataset = None
        self.vid_writer = {}  # dict of {save_path: video_writer, ...}
        self.plotted_img = None
        self.source_type = None
        self.seen = 0
        self.windows = []
        self.batch = None
        self.results = None
        self.transforms = None
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.txt_path = None
        self._lock = threading.Lock()  # for automatic thread-safe inference
        callbacks.add_integration_callbacks(self)

    def preprocess(self, im):
        """
        Prepare input image before inference.

        Args:
            im (torch.Tensor | List[np.ndarray]): Images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
        """
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            if im.shape[-1] == 3:
                im = im[..., ::-1]  # BGR to RGB
            im = im.transpose((0, 3, 1, 2))  # BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        im = im.to(self.device)
        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
        if not_tensor:
            im /= 255  # 0 - 255 to 0.0 - 1.0
        return im

    def inference(self, im, *args, **kwargs):
        """Run inference on a given image using the specified model and arguments."""
        visualize = (
            increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
            if self.args.visualize and (not self.source_type.tensor)
            else False
        )
        return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)

    def pre_transform(self, im):
        """
        Pre-transform input image before inference.

        Args:
            im (List[np.ndarray]): List of images with shape [(h, w, 3) x N].

        Returns:
            (List[np.ndarray]): A list of transformed images.
        """
        same_shapes = len({x.shape for x in im}) == 1
        letterbox = LetterBox(
            self.imgsz,
            auto=same_shapes
            and self.args.rect
            and (self.model.pt or (getattr(self.model, "dynamic", False) and not self.model.imx)),
            stride=self.model.stride,
        )
        return [letterbox(image=x) for x in im]

    def postprocess(self, preds, img, orig_imgs):
        """Post-process predictions for an image and return them."""
        return preds

    def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
        """
        Perform inference on an image or stream.

        Args:
            source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
                Source for inference.
            model (str | Path | torch.nn.Module | None): Model for inference.
            stream (bool): Whether to stream the inference results. If True, returns a generator.
            *args (Any): Additional arguments for the inference method.
            **kwargs (Any): Additional keyword arguments for the inference method.

        Returns:
            (List[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
        """
        self.stream = stream
        if stream:
            return self.stream_inference(source, model, *args, **kwargs)
        else:
            return list(self.stream_inference(source, model, *args, **kwargs))  # merge list of Result into one
219
|
+
|
220
|
+
def predict_cli(self, source=None, model=None):
|
221
|
+
"""
|
222
|
+
Method used for Command Line Interface (CLI) prediction.
|
223
|
+
|
224
|
+
This function is designed to run predictions using the CLI. It sets up the source and model, then processes
|
225
|
+
the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
|
226
|
+
generator without storing results.
|
227
|
+
|
228
|
+
Args:
|
229
|
+
source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
|
230
|
+
Source for inference.
|
231
|
+
model (str | Path | torch.nn.Module | None): Model for inference.
|
232
|
+
|
233
|
+
Note:
|
234
|
+
Do not modify this function or remove the generator. The generator ensures that no outputs are
|
235
|
+
accumulated in memory, which is critical for preventing memory issues during long-running predictions.
|
236
|
+
"""
|
237
|
+
gen = self.stream_inference(source, model)
|
238
|
+
for _ in gen: # sourcery skip: remove-empty-nested-block, noqa
|
239
|
+
pass
|
240
|
+

    def setup_source(self, source):
        """
        Set up source and inference mode.

        Args:
            source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor):
                Source for inference.
        """
        self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
        self.transforms = (
            getattr(
                self.model.model,
                "transforms",
                classify_transforms(self.imgsz[0]),
            )
            if self.args.task == "classify"
            else None
        )
        self.dataset = load_inference_source(
            source=source,
            batch=self.args.batch,
            vid_stride=self.args.vid_stride,
            buffer=self.args.stream_buffer,
            channels=getattr(self.model, "ch", 3),
        )
        self.source_type = self.dataset.source_type
        if not getattr(self, "stream", True) and (
            self.source_type.stream
            or self.source_type.screenshot
            or len(self.dataset) > 1000  # many images
            or any(getattr(self.dataset, "video_flag", [False]))
        ):  # videos
            LOGGER.warning(STREAM_WARNING)
        self.vid_writer = {}

    @smart_inference_mode()
    def stream_inference(self, source=None, model=None, *args, **kwargs):
        """
        Stream real-time inference on camera feed and save results to file.

        Args:
            source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
                Source for inference.
            model (str | Path | torch.nn.Module | None): Model for inference.
            *args (Any): Additional arguments for the inference method.
            **kwargs (Any): Additional keyword arguments for the inference method.

        Yields:
            (ultralytics.engine.results.Results): Results objects.
        """
        if self.args.verbose:
            LOGGER.info("")

        # Setup model
        if not self.model:
            self.setup_model(model)

        with self._lock:  # for thread-safe inference
            # Setup source every time predict is called
            self.setup_source(source if source is not None else self.args.source)

            # Check if save_dir/ label file exists
            if self.args.save or self.args.save_txt:
                (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

            # Warmup model
            if not self.done_warmup:
                self.model.warmup(
                    imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, self.model.ch, *self.imgsz)
                )
                self.done_warmup = True

            self.seen, self.windows, self.batch = 0, [], None
            profilers = (
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
            )
            self.run_callbacks("on_predict_start")
            for self.batch in self.dataset:
                self.run_callbacks("on_predict_batch_start")
                paths, im0s, s = self.batch

                # Preprocess
                with profilers[0]:
                    im = self.preprocess(im0s)

                # Inference
                with profilers[1]:
                    preds = self.inference(im, *args, **kwargs)
                    if self.args.embed:
                        yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                        continue

                # Postprocess
                with profilers[2]:
                    self.results = self.postprocess(preds, im, im0s)
                self.run_callbacks("on_predict_postprocess_end")

                # Visualize, save, write results
                n = len(im0s)
                for i in range(n):
                    self.seen += 1
                    self.results[i].speed = {
                        "preprocess": profilers[0].dt * 1e3 / n,
                        "inference": profilers[1].dt * 1e3 / n,
                        "postprocess": profilers[2].dt * 1e3 / n,
                    }
                    if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                        s[i] += self.write_results(i, Path(paths[i]), im, s)

                # Print batch results
                if self.args.verbose:
                    LOGGER.info("\n".join(s))

                self.run_callbacks("on_predict_batch_end")
                yield from self.results

        # Release assets
        for v in self.vid_writer.values():
            if isinstance(v, cv2.VideoWriter):
                v.release()

        # Print final results
        if self.args.verbose and self.seen:
            t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
            LOGGER.info(
                f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
                f"{(min(self.args.batch, self.seen), getattr(self.model, 'ch', 3), *im.shape[2:])}" % t
            )
        if self.args.save or self.args.save_txt or self.args.save_crop:
            nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
        self.run_callbacks("on_predict_end")

    def setup_model(self, model, verbose=True):
        """
        Initialize YOLO model with given parameters and set it to evaluation mode.

        Args:
            model (str | Path | torch.nn.Module | None): Model to load or use.
            verbose (bool): Whether to print verbose output.
        """
        self.model = AutoBackend(
            weights=model or self.args.model,
            device=select_device(self.args.device, verbose=verbose),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
            batch=self.args.batch,
            fuse=True,
            verbose=verbose,
        )

        self.device = self.model.device  # update device
        self.args.half = self.model.fp16  # update half
        self.model.eval()

    def write_results(self, i, p, im, s):
        """
        Write inference results to a file or directory.

        Args:
            i (int): Index of the current image in the batch.
            p (Path): Path to the current image.
            im (torch.Tensor): Preprocessed image tensor.
            s (List[str]): List of result strings.

        Returns:
            (str): String with result information.
        """
        string = ""  # print string
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        if self.source_type.stream or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
            string += f"{i}: "
            frame = self.dataset.count
        else:
            match = re.search(r"frame (\d+)/", s[i])
            frame = int(match[1]) if match else None  # None if frame undetermined

        self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
        string += "{:g}x{:g} ".format(*im.shape[2:])
        result = self.results[i]
        result.save_dir = self.save_dir.__str__()  # used in other locations
        string += f"{result.verbose()}{result.speed['inference']:.1f}ms"

        # Add predictions to image
        if self.args.save or self.args.show:
            self.plotted_img = result.plot(
                line_width=self.args.line_width,
                boxes=self.args.show_boxes,
                conf=self.args.show_conf,
                labels=self.args.show_labels,
                im_gpu=None if self.args.retina_masks else im[i],
            )

        # Save results
        if self.args.save_txt:
            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
        if self.args.save_crop:
            result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
        if self.args.show:
            self.show(str(p))
        if self.args.save:
            self.save_predicted_images(str(self.save_dir / p.name), frame)

        return string

    def save_predicted_images(self, save_path="", frame=0):
        """
        Save video predictions as mp4 or images as jpg at specified path.

        Args:
            save_path (str): Path to save the results.
            frame (int): Frame number for video mode.
        """
        im = self.plotted_img

        # Save videos and streams
        if self.dataset.mode in {"stream", "video"}:
            fps = self.dataset.fps if self.dataset.mode == "video" else 30
            frames_path = f"{save_path.split('.', 1)[0]}_frames/"
            if save_path not in self.vid_writer:  # new video
                if self.args.save_frames:
                    Path(frames_path).mkdir(parents=True, exist_ok=True)
                suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
                self.vid_writer[save_path] = cv2.VideoWriter(
                    filename=str(Path(save_path).with_suffix(suffix)),
                    fourcc=cv2.VideoWriter_fourcc(*fourcc),
                    fps=fps,  # integer required, floats produce error in MP4 codec
                    frameSize=(im.shape[1], im.shape[0]),  # (width, height)
                )

            # Save video
            self.vid_writer[save_path].write(im)
            if self.args.save_frames:
                cv2.imwrite(f"{frames_path}{frame}.jpg", im)

        # Save images
        else:
            cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im)  # save to JPG for best support

    def show(self, p=""):
        """Display an image in a window."""
        im = self.plotted_img
        if platform.system() == "Linux" and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
            cv2.resizeWindow(p, im.shape[1], im.shape[0])  # (width, height)
        cv2.imshow(p, im)
        cv2.waitKey(300 if self.dataset.mode == "image" else 1)  # 300 ms for images, 1 ms for video/stream

    def run_callbacks(self, event: str):
        """Run all registered callbacks for a specific event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def add_callback(self, event: str, func):
        """Add a callback function for a specific event."""
        self.callbacks[event].append(func)
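
A minimal usage sketch for the predictor above (illustrative only, not part of the packaged file). It drives BasePredictor through the public `ultralytics` YOLO API rather than instantiating it directly, and assumes the wheel is installed, that `yolo11n.pt` weights are available (the library downloads them on first use), and that "video.mp4" is a placeholder source path; the `log_batch` callback name is hypothetical. Passing `stream=True` makes `__call__` return the `stream_inference` generator, so `Results` objects are consumed one at a time instead of being collected into a list, which is the failure mode STREAM_WARNING describes.

# Illustrative usage sketch — not part of the packaged predictor.py above.
# Assumes `yolo11n.pt` is available and "video.mp4" is a placeholder path.
from ultralytics import YOLO


def log_batch(predictor):
    """Hypothetical callback: receives the BasePredictor after each batch."""
    print(f"images seen so far: {predictor.seen}")


model = YOLO("yolo11n.pt")
model.add_callback("on_predict_batch_end", log_batch)

# stream=True yields Results one at a time, avoiding the RAM accumulation
# that STREAM_WARNING cautions against for long videos and streams.
for r in model.predict(source="video.mp4", stream=True, conf=0.25):
    boxes = r.boxes  # Boxes object for bbox outputs
    print(r.speed)  # per-image preprocess/inference/postprocess times in ms

With the default `stream=False`, `__call__` wraps the same generator in `list(...)`, which is fine for a handful of images but holds every `Results` object in memory for long-running sources.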