ultralytics-8.3.143-py3-none-any.whl → ultralytics-8.3.144-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -157
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -7
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +184 -75
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
- ultralytics-8.3.144.dist-info/RECORD +272 -0
- ultralytics-8.3.143.dist-info/RECORD +0 -272
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/engine/exporter.py
CHANGED
@@ -222,11 +222,53 @@ def arange_patch(args):
|
|
222
222
|
|
223
223
|
class Exporter:
|
224
224
|
"""
|
225
|
-
A class for exporting
|
225
|
+
A class for exporting YOLO models to various formats.
|
226
|
+
|
227
|
+
This class provides functionality to export YOLO models to different formats including ONNX, TensorRT, CoreML,
|
228
|
+
TensorFlow, and others. It handles format validation, device selection, model preparation, and the actual export
|
229
|
+
process for each supported format.
|
226
230
|
|
227
231
|
Attributes:
|
228
|
-
args (SimpleNamespace): Configuration for the exporter.
|
229
|
-
callbacks (
|
232
|
+
args (SimpleNamespace): Configuration arguments for the exporter.
|
233
|
+
callbacks (dict): Dictionary of callback functions for different export events.
|
234
|
+
im (torch.Tensor): Input tensor for model inference during export.
|
235
|
+
model (torch.nn.Module): The YOLO model to be exported.
|
236
|
+
file (Path): Path to the model file being exported.
|
237
|
+
output_shape (tuple): Shape of the model output tensor(s).
|
238
|
+
pretty_name (str): Formatted model name for display purposes.
|
239
|
+
metadata (dict): Model metadata including description, author, version, etc.
|
240
|
+
device (torch.device): Device on which the model is loaded.
|
241
|
+
imgsz (tuple): Input image size for the model.
|
242
|
+
|
243
|
+
Methods:
|
244
|
+
__call__: Main export method that handles the export process.
|
245
|
+
get_int8_calibration_dataloader: Build dataloader for INT8 calibration.
|
246
|
+
export_torchscript: Export model to TorchScript format.
|
247
|
+
export_onnx: Export model to ONNX format.
|
248
|
+
export_openvino: Export model to OpenVINO format.
|
249
|
+
export_paddle: Export model to PaddlePaddle format.
|
250
|
+
export_mnn: Export model to MNN format.
|
251
|
+
export_ncnn: Export model to NCNN format.
|
252
|
+
export_coreml: Export model to CoreML format.
|
253
|
+
export_engine: Export model to TensorRT format.
|
254
|
+
export_saved_model: Export model to TensorFlow SavedModel format.
|
255
|
+
export_pb: Export model to TensorFlow GraphDef format.
|
256
|
+
export_tflite: Export model to TensorFlow Lite format.
|
257
|
+
export_edgetpu: Export model to Edge TPU format.
|
258
|
+
export_tfjs: Export model to TensorFlow.js format.
|
259
|
+
export_rknn: Export model to RKNN format.
|
260
|
+
export_imx: Export model to IMX format.
|
261
|
+
|
262
|
+
Examples:
|
263
|
+
Export a YOLOv8 model to ONNX format
|
264
|
+
>>> from ultralytics.engine.exporter import Exporter
|
265
|
+
>>> exporter = Exporter()
|
266
|
+
>>> exporter(model="yolov8n.pt") # exports to yolov8n.onnx
|
267
|
+
|
268
|
+
Export with specific arguments
|
269
|
+
>>> args = {"format": "onnx", "dynamic": True, "half": True}
|
270
|
+
>>> exporter = Exporter(overrides=args)
|
271
|
+
>>> exporter(model="yolov8n.pt")
|
230
272
|
"""
|
231
273
|
|
232
274
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
@@ -536,7 +578,7 @@ class Exporter:
|
|
536
578
|
|
537
579
|
@try_export
|
538
580
|
def export_torchscript(self, prefix=colorstr("TorchScript:")):
|
539
|
-
"""YOLO TorchScript
|
581
|
+
"""Export YOLO model to TorchScript format."""
|
540
582
|
LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
|
541
583
|
f = self.file.with_suffix(".torchscript")
|
542
584
|
|
@@ -553,7 +595,7 @@ class Exporter:
|
|
553
595
|
|
554
596
|
@try_export
|
555
597
|
def export_onnx(self, prefix=colorstr("ONNX:")):
|
556
|
-
"""YOLO ONNX
|
598
|
+
"""Export YOLO model to ONNX format."""
|
557
599
|
requirements = ["onnx>=1.12.0,<1.18.0"]
|
558
600
|
if self.args.simplify:
|
559
601
|
requirements += ["onnxslim>=0.1.53", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
|
@@ -612,7 +654,7 @@ class Exporter:
|
|
612
654
|
|
613
655
|
@try_export
|
614
656
|
def export_openvino(self, prefix=colorstr("OpenVINO:")):
|
615
|
-
"""YOLO OpenVINO
|
657
|
+
"""Export YOLO model to OpenVINO format."""
|
616
658
|
if MACOS:
|
617
659
|
msg = "OpenVINO error in macOS>=15.4 https://github.com/openvinotoolkit/openvino/issues/30023"
|
618
660
|
check_version(MACOS_VERSION, "<15.4", name="macOS ", hard=True, msg=msg)
|
@@ -689,7 +731,7 @@ class Exporter:
|
|
689
731
|
|
690
732
|
@try_export
|
691
733
|
def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
|
692
|
-
"""YOLO
|
734
|
+
"""Export YOLO model to PaddlePaddle format."""
|
693
735
|
assert not IS_JETSON, "Jetson Paddle exports not supported yet"
|
694
736
|
check_requirements(("paddlepaddle-gpu" if torch.cuda.is_available() else "paddlepaddle>=3.0.0", "x2paddle"))
|
695
737
|
import x2paddle # noqa
|
@@ -704,7 +746,7 @@ class Exporter:
|
|
704
746
|
|
705
747
|
@try_export
|
706
748
|
def export_mnn(self, prefix=colorstr("MNN:")):
|
707
|
-
"""YOLO MNN
|
749
|
+
"""Export YOLO model to MNN format using MNN https://github.com/alibaba/MNN."""
|
708
750
|
f_onnx, _ = self.export_onnx() # get onnx model first
|
709
751
|
|
710
752
|
check_requirements("MNN>=2.9.6")
|
@@ -729,7 +771,7 @@ class Exporter:
|
|
729
771
|
|
730
772
|
@try_export
|
731
773
|
def export_ncnn(self, prefix=colorstr("NCNN:")):
|
732
|
-
"""YOLO NCNN
|
774
|
+
"""Export YOLO model to NCNN format using PNNX https://github.com/pnnx/pnnx."""
|
733
775
|
check_requirements("ncnn")
|
734
776
|
import ncnn # noqa
|
735
777
|
|
@@ -797,7 +839,7 @@ class Exporter:
|
|
797
839
|
|
798
840
|
@try_export
|
799
841
|
def export_coreml(self, prefix=colorstr("CoreML:")):
|
800
|
-
"""YOLO CoreML
|
842
|
+
"""Export YOLO model to CoreML format."""
|
801
843
|
mlmodel = self.args.format.lower() == "mlmodel" # legacy *.mlmodel export format requested
|
802
844
|
check_requirements("coremltools>=8.0")
|
803
845
|
import coremltools as ct # noqa
|
@@ -876,7 +918,7 @@ class Exporter:
|
|
876
918
|
|
877
919
|
@try_export
|
878
920
|
def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
|
879
|
-
"""YOLO TensorRT
|
921
|
+
"""Export YOLO model to TensorRT format https://developer.nvidia.com/tensorrt."""
|
880
922
|
assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
|
881
923
|
f_onnx, _ = self.export_onnx() # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016
|
882
924
|
|
@@ -912,7 +954,7 @@ class Exporter:
|
|
912
954
|
|
913
955
|
@try_export
|
914
956
|
def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
|
915
|
-
"""YOLO TensorFlow SavedModel
|
957
|
+
"""Export YOLO model to TensorFlow SavedModel format."""
|
916
958
|
cuda = torch.cuda.is_available()
|
917
959
|
try:
|
918
960
|
import tensorflow as tf # noqa
|
@@ -1002,7 +1044,7 @@ class Exporter:
|
|
1002
1044
|
|
1003
1045
|
@try_export
|
1004
1046
|
def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
|
1005
|
-
"""YOLO TensorFlow GraphDef *.pb
|
1047
|
+
"""Export YOLO model to TensorFlow GraphDef *.pb format https://github.com/leimao/Frozen-Graph-TensorFlow."""
|
1006
1048
|
import tensorflow as tf # noqa
|
1007
1049
|
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
|
1008
1050
|
|
@@ -1018,7 +1060,7 @@ class Exporter:
|
|
1018
1060
|
|
1019
1061
|
@try_export
|
1020
1062
|
def export_tflite(self, prefix=colorstr("TensorFlow Lite:")):
|
1021
|
-
"""YOLO TensorFlow Lite
|
1063
|
+
"""Export YOLO model to TensorFlow Lite format."""
|
1022
1064
|
# BUG https://github.com/ultralytics/ultralytics/issues/13436
|
1023
1065
|
import tensorflow as tf # noqa
|
1024
1066
|
|
@@ -1034,7 +1076,7 @@ class Exporter:
|
|
1034
1076
|
|
1035
1077
|
@try_export
|
1036
1078
|
def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
|
1037
|
-
"""YOLO Edge TPU
|
1079
|
+
"""Export YOLO model to Edge TPU format https://coral.ai/docs/edgetpu/models-intro/."""
|
1038
1080
|
cmd = "edgetpu_compiler --version"
|
1039
1081
|
help_url = "https://coral.ai/docs/edgetpu/compiler/"
|
1040
1082
|
assert LINUX, f"export only supported on Linux. See {help_url}"
|
@@ -1069,7 +1111,7 @@ class Exporter:
|
|
1069
1111
|
|
1070
1112
|
@try_export
|
1071
1113
|
def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
|
1072
|
-
"""YOLO TensorFlow.js
|
1114
|
+
"""Export YOLO model to TensorFlow.js format."""
|
1073
1115
|
check_requirements("tensorflowjs")
|
1074
1116
|
import tensorflow as tf
|
1075
1117
|
import tensorflowjs as tfjs # noqa
|
@@ -1102,7 +1144,7 @@ class Exporter:
|
|
1102
1144
|
|
1103
1145
|
@try_export
|
1104
1146
|
def export_rknn(self, prefix=colorstr("RKNN:")):
|
1105
|
-
"""YOLO RKNN
|
1147
|
+
"""Export YOLO model to RKNN format."""
|
1106
1148
|
LOGGER.info(f"\n{prefix} starting export with rknn-toolkit2...")
|
1107
1149
|
|
1108
1150
|
check_requirements("rknn-toolkit2")
|
@@ -1129,7 +1171,7 @@ class Exporter:
|
|
1129
1171
|
|
1130
1172
|
@try_export
|
1131
1173
|
def export_imx(self, prefix=colorstr("IMX:")):
|
1132
|
-
"""YOLO IMX
|
1174
|
+
"""Export YOLO model to IMX format."""
|
1133
1175
|
gptq = False
|
1134
1176
|
assert LINUX, (
|
1135
1177
|
"export only supported on Linux. "
|
@@ -1212,6 +1254,8 @@ class Exporter:
|
|
1212
1254
|
)
|
1213
1255
|
|
1214
1256
|
class NMSWrapper(torch.nn.Module):
|
1257
|
+
"""Wrap PyTorch Module with multiclass_nms layer from sony_custom_layers."""
|
1258
|
+
|
1215
1259
|
def __init__(
|
1216
1260
|
self,
|
1217
1261
|
model: torch.nn.Module,
|
@@ -1220,13 +1264,13 @@ class Exporter:
|
|
1220
1264
|
max_detections: int = 300,
|
1221
1265
|
):
|
1222
1266
|
"""
|
1223
|
-
|
1267
|
+
Initialize NMSWrapper with PyTorch Module and NMS parameters.
|
1224
1268
|
|
1225
1269
|
Args:
|
1226
|
-
model (nn.Module): Model instance.
|
1270
|
+
model (torch.nn.Module): Model instance.
|
1227
1271
|
score_threshold (float): Score threshold for non-maximum suppression.
|
1228
1272
|
iou_threshold (float): Intersection over union threshold for non-maximum suppression.
|
1229
|
-
max_detections (
|
1273
|
+
max_detections (int): The number of detections to return.
|
1230
1274
|
"""
|
1231
1275
|
super().__init__()
|
1232
1276
|
self.model = model
|
@@ -1235,6 +1279,7 @@ class Exporter:
|
|
1235
1279
|
self.max_detections = max_detections
|
1236
1280
|
|
1237
1281
|
def forward(self, images):
|
1282
|
+
"""Forward pass with model inference and NMS post-processing."""
|
1238
1283
|
# model inference
|
1239
1284
|
outputs = self.model(images)
|
1240
1285
|
|
@@ -1289,7 +1334,7 @@ class Exporter:
|
|
1289
1334
|
zf.writestr("metadata.json", json.dumps(self.metadata, indent=2))
|
1290
1335
|
|
1291
1336
|
def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
|
1292
|
-
"""
|
1337
|
+
"""Create CoreML pipeline with NMS for YOLO detection models."""
|
1293
1338
|
import coremltools as ct # noqa
|
1294
1339
|
|
1295
1340
|
LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
|
@@ -1395,7 +1440,7 @@ class Exporter:
|
|
1395
1440
|
return model
|
1396
1441
|
|
1397
1442
|
def add_callback(self, event: str, callback):
|
1398
|
-
"""
|
1443
|
+
"""Append the given callback to the specified event."""
|
1399
1444
|
self.callbacks[event].append(callback)
|
1400
1445
|
|
1401
1446
|
def run_callbacks(self, event: str):
|
@@ -1408,7 +1453,13 @@ class IOSDetectModel(torch.nn.Module):
|
|
1408
1453
|
"""Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
|
1409
1454
|
|
1410
1455
|
def __init__(self, model, im):
|
1411
|
-
"""
|
1456
|
+
"""
|
1457
|
+
Initialize the IOSDetectModel class with a YOLO model and example image.
|
1458
|
+
|
1459
|
+
Args:
|
1460
|
+
model (torch.nn.Module): The YOLO model to wrap.
|
1461
|
+
im (torch.Tensor): Example input tensor with shape (B, C, H, W).
|
1462
|
+
"""
|
1412
1463
|
super().__init__()
|
1413
1464
|
_, _, h, w = im.shape # batch, channel, height, width
|
1414
1465
|
self.model = model
|
@@ -1432,7 +1483,7 @@ class NMSModel(torch.nn.Module):
|
|
1432
1483
|
Initialize the NMSModel.
|
1433
1484
|
|
1434
1485
|
Args:
|
1435
|
-
model (torch.nn.
|
1486
|
+
model (torch.nn.Module): The model to wrap with NMS postprocessing.
|
1436
1487
|
args (Namespace): The export arguments.
|
1437
1488
|
"""
|
1438
1489
|
super().__init__()
|
@@ -1443,13 +1494,14 @@ class NMSModel(torch.nn.Module):
|
|
1443
1494
|
|
1444
1495
|
def forward(self, x):
|
1445
1496
|
"""
|
1446
|
-
|
1497
|
+
Perform inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
|
1447
1498
|
|
1448
1499
|
Args:
|
1449
1500
|
x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).
|
1450
1501
|
|
1451
1502
|
Returns:
|
1452
|
-
(torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the
|
1503
|
+
(torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the
|
1504
|
+
number of detections after NMS.
|
1453
1505
|
"""
|
1454
1506
|
from functools import partial
|
1455
1507
|
|
ultralytics/engine/model.py
CHANGED
@@ -48,25 +48,25 @@ class Model(torch.nn.Module):
|
|
48
48
|
|
49
49
|
Methods:
|
50
50
|
__call__: Alias for the predict method, enabling the model instance to be callable.
|
51
|
-
_new:
|
52
|
-
_load:
|
53
|
-
_check_is_pytorch_model:
|
54
|
-
reset_weights:
|
55
|
-
load:
|
56
|
-
save:
|
57
|
-
info:
|
58
|
-
fuse:
|
59
|
-
predict:
|
60
|
-
track:
|
61
|
-
val:
|
62
|
-
benchmark:
|
63
|
-
export:
|
64
|
-
train:
|
65
|
-
tune:
|
66
|
-
_apply:
|
67
|
-
add_callback:
|
68
|
-
clear_callback:
|
69
|
-
reset_callbacks:
|
51
|
+
_new: Initialize a new model based on a configuration file.
|
52
|
+
_load: Load a model from a checkpoint file.
|
53
|
+
_check_is_pytorch_model: Ensure that the model is a PyTorch model.
|
54
|
+
reset_weights: Reset the model's weights to their initial state.
|
55
|
+
load: Load model weights from a specified file.
|
56
|
+
save: Save the current state of the model to a file.
|
57
|
+
info: Log or return information about the model.
|
58
|
+
fuse: Fuse Conv2d and BatchNorm2d layers for optimized inference.
|
59
|
+
predict: Perform object detection predictions.
|
60
|
+
track: Perform object tracking.
|
61
|
+
val: Validate the model on a dataset.
|
62
|
+
benchmark: Benchmark the model on various export formats.
|
63
|
+
export: Export the model to different formats.
|
64
|
+
train: Train the model on a dataset.
|
65
|
+
tune: Perform hyperparameter tuning.
|
66
|
+
_apply: Apply a function to the model's tensors.
|
67
|
+
add_callback: Add a callback function for an event.
|
68
|
+
clear_callback: Clear all callbacks for an event.
|
69
|
+
reset_callbacks: Reset all callbacks to their default functions.
|
70
70
|
|
71
71
|
Examples:
|
72
72
|
>>> from ultralytics import YOLO
|
@@ -94,7 +94,7 @@ class Model(torch.nn.Module):
|
|
94
94
|
Args:
|
95
95
|
model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a
|
96
96
|
model name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.
|
97
|
-
task (str
|
97
|
+
task (str, optional): The specific task for the model. If None, it will be inferred from the config.
|
98
98
|
verbose (bool): If True, enables verbose output during the model's initialization and subsequent
|
99
99
|
operations.
|
100
100
|
|
@@ -242,9 +242,9 @@ class Model(torch.nn.Module):
|
|
242
242
|
|
243
243
|
Args:
|
244
244
|
cfg (str): Path to the model configuration file in YAML format.
|
245
|
-
task (str
|
246
|
-
model (torch.nn.Module
|
247
|
-
a new one.
|
245
|
+
task (str, optional): The specific task for the model. If None, it will be inferred from the config.
|
246
|
+
model (torch.nn.Module, optional): A custom model instance. If provided, it will be used instead of
|
247
|
+
creating a new one.
|
248
248
|
verbose (bool): If True, displays model information during loading.
|
249
249
|
|
250
250
|
Raises:
|
@@ -276,7 +276,7 @@ class Model(torch.nn.Module):
|
|
276
276
|
|
277
277
|
Args:
|
278
278
|
weights (str): Path to the model weights file to be loaded.
|
279
|
-
task (str
|
279
|
+
task (str, optional): The task associated with the model. If None, it will be inferred from the model.
|
280
280
|
|
281
281
|
Raises:
|
282
282
|
FileNotFoundError: If the specified weights file does not exist or is inaccessible.
|
@@ -367,7 +367,7 @@ class Model(torch.nn.Module):
|
|
367
367
|
name and shape and transfers them to the model.
|
368
368
|
|
369
369
|
Args:
|
370
|
-
weights (
|
370
|
+
weights (str | Path): Path to the weights file or a weights object.
|
371
371
|
|
372
372
|
Returns:
|
373
373
|
(Model): The instance of the class with loaded weights.
|
@@ -501,7 +501,7 @@ class Model(torch.nn.Module):
|
|
501
501
|
**kwargs: Any,
|
502
502
|
) -> List[Results]:
|
503
503
|
"""
|
504
|
-
|
504
|
+
Perform predictions on the given image source using the YOLO model.
|
505
505
|
|
506
506
|
This method facilitates the prediction process, allowing various configurations through keyword arguments.
|
507
507
|
It supports predictions with custom predictors or the default predictor method. The method handles different
|
@@ -512,7 +512,7 @@ class Model(torch.nn.Module):
|
|
512
512
|
of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL
|
513
513
|
images, numpy arrays, and torch tensors.
|
514
514
|
stream (bool): If True, treats the input source as a continuous stream for predictions.
|
515
|
-
predictor (BasePredictor
|
515
|
+
predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.
|
516
516
|
If None, the method uses a default predictor.
|
517
517
|
**kwargs (Any): Additional keyword arguments for configuring the prediction process.
|
518
518
|
|
@@ -562,14 +562,14 @@ class Model(torch.nn.Module):
|
|
562
562
|
**kwargs: Any,
|
563
563
|
) -> List[Results]:
|
564
564
|
"""
|
565
|
-
|
565
|
+
Conduct object tracking on the specified input source using the registered trackers.
|
566
566
|
|
567
567
|
This method performs object tracking using the model's predictors and optionally registered trackers. It handles
|
568
568
|
various input sources such as file paths or video streams, and supports customization through keyword arguments.
|
569
569
|
The method registers trackers if not already present and can persist them between calls.
|
570
570
|
|
571
571
|
Args:
|
572
|
-
source (
|
572
|
+
source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor, optional): Input source for object
|
573
573
|
tracking. Can be a file path, URL, or video stream.
|
574
574
|
stream (bool): If True, treats the input source as a continuous video stream.
|
575
575
|
persist (bool): If True, persists trackers between different calls to this method.
|
@@ -611,8 +611,8 @@ class Model(torch.nn.Module):
|
|
611
611
|
configurations, method-specific defaults, and user-provided arguments to configure the validation process.
|
612
612
|
|
613
613
|
Args:
|
614
|
-
validator (ultralytics.engine.validator.BaseValidator
|
615
|
-
validating the model.
|
614
|
+
validator (ultralytics.engine.validator.BaseValidator, optional): An instance of a custom validator class
|
615
|
+
for validating the model.
|
616
616
|
**kwargs (Any): Arbitrary keyword arguments for customizing the validation process.
|
617
617
|
|
618
618
|
Returns:
|
@@ -738,7 +738,7 @@ class Model(torch.nn.Module):
|
|
738
738
|
**kwargs: Any,
|
739
739
|
):
|
740
740
|
"""
|
741
|
-
|
741
|
+
Train the model using the specified dataset and training configuration.
|
742
742
|
|
743
743
|
This method facilitates model training with a range of customizable settings. It supports training with a
|
744
744
|
custom trainer or the default training approach. The method handles scenarios such as resuming training
|
@@ -749,7 +749,7 @@ class Model(torch.nn.Module):
|
|
749
749
|
configurations, method-specific defaults, and user-provided arguments to configure the training process.
|
750
750
|
|
751
751
|
Args:
|
752
|
-
trainer (BaseTrainer
|
752
|
+
trainer (BaseTrainer, optional): Custom trainer instance for model training. If None, uses default.
|
753
753
|
**kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
|
754
754
|
data (str): Path to dataset configuration file.
|
755
755
|
epochs (int): Number of training epochs.
|
@@ -810,7 +810,7 @@ class Model(torch.nn.Module):
|
|
810
810
|
**kwargs: Any,
|
811
811
|
):
|
812
812
|
"""
|
813
|
-
|
813
|
+
Conduct hyperparameter tuning for the model, with an option to use Ray Tune.
|
814
814
|
|
815
815
|
This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method.
|
816
816
|
When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module.
|
@@ -881,7 +881,7 @@ class Model(torch.nn.Module):
|
|
881
881
|
@property
|
882
882
|
def names(self) -> Dict[int, str]:
|
883
883
|
"""
|
884
|
-
|
884
|
+
Retrieve the class names associated with the loaded model.
|
885
885
|
|
886
886
|
This property returns the class names if they are defined in the model. It checks the class names for validity
|
887
887
|
using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
|
@@ -935,7 +935,7 @@ class Model(torch.nn.Module):
|
|
935
935
|
@property
|
936
936
|
def transforms(self):
|
937
937
|
"""
|
938
|
-
|
938
|
+
Retrieve the transformations applied to the input data of the loaded model.
|
939
939
|
|
940
940
|
This property returns the transformations if they are defined in the model. The transforms
|
941
941
|
typically include preprocessing steps like resizing, normalization, and data augmentation
|
@@ -982,7 +982,7 @@ class Model(torch.nn.Module):
|
|
982
982
|
|
983
983
|
def clear_callback(self, event: str) -> None:
|
984
984
|
"""
|
985
|
-
|
985
|
+
Clear all callback functions registered for a specified event.
|
986
986
|
|
987
987
|
This method removes all custom and default callback functions associated with the given event.
|
988
988
|
It resets the callback list for the specified event to an empty list, effectively removing all
|
@@ -1062,7 +1062,7 @@ class Model(torch.nn.Module):
|
|
1062
1062
|
|
1063
1063
|
def _smart_load(self, key: str):
|
1064
1064
|
"""
|
1065
|
-
Intelligently
|
1065
|
+
Intelligently load the appropriate module based on the model task.
|
1066
1066
|
|
1067
1067
|
This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
|
1068
1068
|
based on the current task of the model and the provided key. It uses the task_map dictionary to determine
|
@@ -1092,7 +1092,7 @@ class Model(torch.nn.Module):
|
|
1092
1092
|
@property
|
1093
1093
|
def task_map(self) -> dict:
|
1094
1094
|
"""
|
1095
|
-
|
1095
|
+
Provide a mapping from model tasks to corresponding classes for different modes.
|
1096
1096
|
|
1097
1097
|
This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
|
1098
1098
|
to a nested dictionary. The nested dictionary contains mappings for different operational modes
|
ultralytics/engine/predictor.py
CHANGED
@@ -36,6 +36,7 @@ import platform
|
|
36
36
|
import re
|
37
37
|
import threading
|
38
38
|
from pathlib import Path
|
39
|
+
from typing import Any, Dict, List, Optional, Union
|
39
40
|
|
40
41
|
import cv2
|
41
42
|
import numpy as np
|
@@ -78,15 +79,15 @@ class BasePredictor:
|
|
78
79
|
data (dict): Data configuration.
|
79
80
|
device (torch.device): Device used for prediction.
|
80
81
|
dataset (Dataset): Dataset used for prediction.
|
81
|
-
vid_writer (
|
82
|
-
plotted_img (
|
82
|
+
vid_writer (Dict[str, cv2.VideoWriter]): Dictionary of {save_path: video_writer} for saving video output.
|
83
|
+
plotted_img (np.ndarray): Last plotted image.
|
83
84
|
source_type (SimpleNamespace): Type of input source.
|
84
85
|
seen (int): Number of images processed.
|
85
|
-
windows (
|
86
|
+
windows (List[str]): List of window names for visualization.
|
86
87
|
batch (tuple): Current batch data.
|
87
|
-
results (
|
88
|
+
results (List[Any]): Current batch results.
|
88
89
|
transforms (callable): Image transforms for classification.
|
89
|
-
callbacks (
|
90
|
+
callbacks (Dict[str, List[callable]]): Callback functions for different events.
|
90
91
|
txt_path (Path): Path to save text results.
|
91
92
|
_lock (threading.Lock): Lock for thread-safe inference.
|
92
93
|
|
@@ -105,14 +106,19 @@ class BasePredictor:
|
|
105
106
|
add_callback: Register a new callback function.
|
106
107
|
"""
|
107
108
|
|
108
|
-
def __init__(
|
109
|
+
def __init__(
|
110
|
+
self,
|
111
|
+
cfg=DEFAULT_CFG,
|
112
|
+
overrides: Optional[Dict[str, Any]] = None,
|
113
|
+
_callbacks: Optional[Dict[str, List[callable]]] = None,
|
114
|
+
):
|
109
115
|
"""
|
110
116
|
Initialize the BasePredictor class.
|
111
117
|
|
112
118
|
Args:
|
113
119
|
cfg (str | dict): Path to a configuration file or a configuration dictionary.
|
114
|
-
overrides (dict
|
115
|
-
_callbacks (dict
|
120
|
+
overrides (dict, optional): Configuration overrides.
|
121
|
+
_callbacks (dict, optional): Dictionary of callback functions.
|
116
122
|
"""
|
117
123
|
self.args = get_cfg(cfg, overrides)
|
118
124
|
self.save_dir = get_save_dir(self.args)
|
@@ -141,12 +147,15 @@ class BasePredictor:
|
|
141
147
|
self._lock = threading.Lock() # for automatic thread-safe inference
|
142
148
|
callbacks.add_integration_callbacks(self)
|
143
149
|
|
144
|
-
def preprocess(self, im):
|
150
|
+
def preprocess(self, im: Union[torch.Tensor, List[np.ndarray]]) -> torch.Tensor:
|
145
151
|
"""
|
146
|
-
|
152
|
+
Prepare input image before inference.
|
147
153
|
|
148
154
|
Args:
|
149
|
-
im (torch.Tensor | List
|
155
|
+
im (torch.Tensor | List[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.
|
156
|
+
|
157
|
+
Returns:
|
158
|
+
(torch.Tensor): Preprocessed image tensor of shape (N, 3, H, W).
|
150
159
|
"""
|
151
160
|
not_tensor = not isinstance(im, torch.Tensor)
|
152
161
|
if not_tensor:
|
@@ -163,7 +172,7 @@ class BasePredictor:
|
|
163
172
|
im /= 255 # 0 - 255 to 0.0 - 1.0
|
164
173
|
return im
|
165
174
|
|
166
|
-
def inference(self, im, *args, **kwargs):
|
175
|
+
def inference(self, im: torch.Tensor, *args, **kwargs):
|
167
176
|
"""Run inference on a given image using the specified model and arguments."""
|
168
177
|
visualize = (
|
169
178
|
increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
|
@@ -172,15 +181,15 @@ class BasePredictor:
|
|
172
181
|
)
|
173
182
|
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
|
174
183
|
|
175
|
-
def pre_transform(self, im):
|
184
|
+
def pre_transform(self, im: List[np.ndarray]) -> List[np.ndarray]:
|
176
185
|
"""
|
177
186
|
Pre-transform input image before inference.
|
178
187
|
|
179
188
|
Args:
|
180
|
-
im (List[np.ndarray]):
|
189
|
+
im (List[np.ndarray]): List of images with shape [(H, W, 3) x N].
|
181
190
|
|
182
191
|
Returns:
|
183
|
-
(List[np.ndarray]):
|
192
|
+
(List[np.ndarray]): List of transformed images.
|
184
193
|
"""
|
185
194
|
same_shapes = len({x.shape for x in im}) == 1
|
186
195
|
letterbox = LetterBox(
|
@@ -196,14 +205,14 @@ class BasePredictor:
|
|
196
205
|
"""Post-process predictions for an image and return them."""
|
197
206
|
return preds
|
198
207
|
|
199
|
-
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
|
208
|
+
def __call__(self, source=None, model=None, stream: bool = False, *args, **kwargs):
|
200
209
|
"""
|
201
210
|
Perform inference on an image or stream.
|
202
211
|
|
203
212
|
Args:
|
204
|
-
source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor
|
213
|
+
source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):
|
205
214
|
Source for inference.
|
206
|
-
model (str | Path | torch.nn.Module
|
215
|
+
model (str | Path | torch.nn.Module, optional): Model for inference.
|
207
216
|
stream (bool): Whether to stream the inference results. If True, returns a generator.
|
208
217
|
*args (Any): Additional arguments for the inference method.
|
209
218
|
**kwargs (Any): Additional keyword arguments for the inference method.
|
@@ -226,9 +235,9 @@ class BasePredictor:
|
|
226
235
|
generator without storing results.
|
227
236
|
|
228
237
|
Args:
|
229
|
-
source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor
|
238
|
+
source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):
|
230
239
|
Source for inference.
|
231
|
-
model (str | Path | torch.nn.Module
|
240
|
+
model (str | Path | torch.nn.Module, optional): Model for inference.
|
232
241
|
|
233
242
|
Note:
|
234
243
|
Do not modify this function or remove the generator. The generator ensures that no outputs are
|
@@ -270,9 +279,9 @@ class BasePredictor:
|
|
270
279
|
Stream real-time inference on camera feed and save results to file.
|
271
280
|
|
272
281
|
Args:
|
273
|
-
source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor
|
282
|
+
source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):
|
274
283
|
Source for inference.
|
275
|
-
model (str | Path | torch.nn.Module
|
284
|
+
model (str | Path | torch.nn.Module, optional): Model for inference.
|
276
285
|
*args (Any): Additional arguments for the inference method.
|
277
286
|
**kwargs (Any): Additional keyword arguments for the inference method.
|
278
287
|
|
@@ -365,12 +374,12 @@ class BasePredictor:
|
|
365
374
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
366
375
|
self.run_callbacks("on_predict_end")
|
367
376
|
|
368
|
-
def setup_model(self, model, verbose=True):
|
377
|
+
def setup_model(self, model, verbose: bool = True):
|
369
378
|
"""
|
370
379
|
Initialize YOLO model with given parameters and set it to evaluation mode.
|
371
380
|
|
372
381
|
Args:
|
373
|
-
model (str | Path | torch.nn.Module
|
382
|
+
model (str | Path | torch.nn.Module, optional): Model to load or use.
|
374
383
|
verbose (bool): Whether to print verbose output.
|
375
384
|
"""
|
376
385
|
self.model = AutoBackend(
|
@@ -390,7 +399,7 @@ class BasePredictor:
|
|
390
399
|
self.args.imgsz = self.model.imgsz # reuse imgsz from export metadata
|
391
400
|
self.model.eval()
|
392
401
|
|
393
|
-
def write_results(self, i, p, im, s):
|
402
|
+
def write_results(self, i: int, p: Path, im: torch.Tensor, s: List[str]) -> str:
|
394
403
|
"""
|
395
404
|
Write inference results to a file or directory.
|
396
405
|
|
@@ -441,7 +450,7 @@ class BasePredictor:
|
|
441
450
|
|
442
451
|
return string
|
443
452
|
|
444
|
-
def save_predicted_images(self, save_path="", frame=0):
|
453
|
+
def save_predicted_images(self, save_path: str = "", frame: int = 0):
|
445
454
|
"""
|
446
455
|
Save video predictions as mp4 or images as jpg at specified path.
|
447
456
|
|
@@ -475,7 +484,7 @@ class BasePredictor:
|
|
475
484
|
else:
|
476
485
|
cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im) # save to JPG for best support
|
477
486
|
|
478
|
-
def show(self, p=""):
|
487
|
+
def show(self, p: str = ""):
|
479
488
|
"""Display an image in a window."""
|
480
489
|
im = self.plotted_img
|
481
490
|
if platform.system() == "Linux" and p not in self.windows:
|
@@ -490,6 +499,6 @@ class BasePredictor:
|
|
490
499
|
for callback in self.callbacks.get(event, []):
|
491
500
|
callback(self)
|
492
501
|
|
493
|
-
def add_callback(self, event: str, func):
|
502
|
+
def add_callback(self, event: str, func: callable):
|
494
503
|
"""Add a callback function for a specific event."""
|
495
504
|
self.callbacks[event].append(func)
|