ultralytics 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +52 -51
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +191 -161
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +4 -6
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +2 -2
- ultralytics/solutions/instance_segmentation.py +7 -4
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -11
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +189 -79
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +45 -29
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/METADATA +2 -2
- ultralytics-8.3.145.dist-info/RECORD +272 -0
- ultralytics-8.3.143.dist-info/RECORD +0 -272
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/top_level.txt +0 -0
ultralytics/engine/exporter.py
CHANGED
@@ -222,11 +222,53 @@ def arange_patch(args):
|
|
222
222
|
|
223
223
|
class Exporter:
|
224
224
|
"""
|
225
|
-
A class for exporting
|
225
|
+
A class for exporting YOLO models to various formats.
|
226
|
+
|
227
|
+
This class provides functionality to export YOLO models to different formats including ONNX, TensorRT, CoreML,
|
228
|
+
TensorFlow, and others. It handles format validation, device selection, model preparation, and the actual export
|
229
|
+
process for each supported format.
|
226
230
|
|
227
231
|
Attributes:
|
228
|
-
args (SimpleNamespace): Configuration for the exporter.
|
229
|
-
callbacks (
|
232
|
+
args (SimpleNamespace): Configuration arguments for the exporter.
|
233
|
+
callbacks (dict): Dictionary of callback functions for different export events.
|
234
|
+
im (torch.Tensor): Input tensor for model inference during export.
|
235
|
+
model (torch.nn.Module): The YOLO model to be exported.
|
236
|
+
file (Path): Path to the model file being exported.
|
237
|
+
output_shape (tuple): Shape of the model output tensor(s).
|
238
|
+
pretty_name (str): Formatted model name for display purposes.
|
239
|
+
metadata (dict): Model metadata including description, author, version, etc.
|
240
|
+
device (torch.device): Device on which the model is loaded.
|
241
|
+
imgsz (tuple): Input image size for the model.
|
242
|
+
|
243
|
+
Methods:
|
244
|
+
__call__: Main export method that handles the export process.
|
245
|
+
get_int8_calibration_dataloader: Build dataloader for INT8 calibration.
|
246
|
+
export_torchscript: Export model to TorchScript format.
|
247
|
+
export_onnx: Export model to ONNX format.
|
248
|
+
export_openvino: Export model to OpenVINO format.
|
249
|
+
export_paddle: Export model to PaddlePaddle format.
|
250
|
+
export_mnn: Export model to MNN format.
|
251
|
+
export_ncnn: Export model to NCNN format.
|
252
|
+
export_coreml: Export model to CoreML format.
|
253
|
+
export_engine: Export model to TensorRT format.
|
254
|
+
export_saved_model: Export model to TensorFlow SavedModel format.
|
255
|
+
export_pb: Export model to TensorFlow GraphDef format.
|
256
|
+
export_tflite: Export model to TensorFlow Lite format.
|
257
|
+
export_edgetpu: Export model to Edge TPU format.
|
258
|
+
export_tfjs: Export model to TensorFlow.js format.
|
259
|
+
export_rknn: Export model to RKNN format.
|
260
|
+
export_imx: Export model to IMX format.
|
261
|
+
|
262
|
+
Examples:
|
263
|
+
Export a YOLOv8 model to ONNX format
|
264
|
+
>>> from ultralytics.engine.exporter import Exporter
|
265
|
+
>>> exporter = Exporter()
|
266
|
+
>>> exporter(model="yolov8n.pt") # exports to yolov8n.onnx
|
267
|
+
|
268
|
+
Export with specific arguments
|
269
|
+
>>> args = {"format": "onnx", "dynamic": True, "half": True}
|
270
|
+
>>> exporter = Exporter(overrides=args)
|
271
|
+
>>> exporter(model="yolov8n.pt")
|
230
272
|
"""
|
231
273
|
|
232
274
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
@@ -536,7 +578,7 @@ class Exporter:
|
|
536
578
|
|
537
579
|
@try_export
|
538
580
|
def export_torchscript(self, prefix=colorstr("TorchScript:")):
|
539
|
-
"""YOLO TorchScript
|
581
|
+
"""Export YOLO model to TorchScript format."""
|
540
582
|
LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
|
541
583
|
f = self.file.with_suffix(".torchscript")
|
542
584
|
|
@@ -553,7 +595,7 @@ class Exporter:
|
|
553
595
|
|
554
596
|
@try_export
|
555
597
|
def export_onnx(self, prefix=colorstr("ONNX:")):
|
556
|
-
"""YOLO ONNX
|
598
|
+
"""Export YOLO model to ONNX format."""
|
557
599
|
requirements = ["onnx>=1.12.0,<1.18.0"]
|
558
600
|
if self.args.simplify:
|
559
601
|
requirements += ["onnxslim>=0.1.53", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
|
@@ -612,7 +654,7 @@ class Exporter:
|
|
612
654
|
|
613
655
|
@try_export
|
614
656
|
def export_openvino(self, prefix=colorstr("OpenVINO:")):
|
615
|
-
"""YOLO OpenVINO
|
657
|
+
"""Export YOLO model to OpenVINO format."""
|
616
658
|
if MACOS:
|
617
659
|
msg = "OpenVINO error in macOS>=15.4 https://github.com/openvinotoolkit/openvino/issues/30023"
|
618
660
|
check_version(MACOS_VERSION, "<15.4", name="macOS ", hard=True, msg=msg)
|
@@ -689,7 +731,7 @@ class Exporter:
|
|
689
731
|
|
690
732
|
@try_export
|
691
733
|
def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
|
692
|
-
"""YOLO
|
734
|
+
"""Export YOLO model to PaddlePaddle format."""
|
693
735
|
assert not IS_JETSON, "Jetson Paddle exports not supported yet"
|
694
736
|
check_requirements(("paddlepaddle-gpu" if torch.cuda.is_available() else "paddlepaddle>=3.0.0", "x2paddle"))
|
695
737
|
import x2paddle # noqa
|
@@ -704,7 +746,7 @@ class Exporter:
|
|
704
746
|
|
705
747
|
@try_export
|
706
748
|
def export_mnn(self, prefix=colorstr("MNN:")):
|
707
|
-
"""YOLO MNN
|
749
|
+
"""Export YOLO model to MNN format using MNN https://github.com/alibaba/MNN."""
|
708
750
|
f_onnx, _ = self.export_onnx() # get onnx model first
|
709
751
|
|
710
752
|
check_requirements("MNN>=2.9.6")
|
@@ -729,7 +771,7 @@ class Exporter:
|
|
729
771
|
|
730
772
|
@try_export
|
731
773
|
def export_ncnn(self, prefix=colorstr("NCNN:")):
|
732
|
-
"""YOLO NCNN
|
774
|
+
"""Export YOLO model to NCNN format using PNNX https://github.com/pnnx/pnnx."""
|
733
775
|
check_requirements("ncnn")
|
734
776
|
import ncnn # noqa
|
735
777
|
|
@@ -797,7 +839,7 @@ class Exporter:
|
|
797
839
|
|
798
840
|
@try_export
|
799
841
|
def export_coreml(self, prefix=colorstr("CoreML:")):
|
800
|
-
"""YOLO CoreML
|
842
|
+
"""Export YOLO model to CoreML format."""
|
801
843
|
mlmodel = self.args.format.lower() == "mlmodel" # legacy *.mlmodel export format requested
|
802
844
|
check_requirements("coremltools>=8.0")
|
803
845
|
import coremltools as ct # noqa
|
@@ -876,7 +918,7 @@ class Exporter:
|
|
876
918
|
|
877
919
|
@try_export
|
878
920
|
def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
|
879
|
-
"""YOLO TensorRT
|
921
|
+
"""Export YOLO model to TensorRT format https://developer.nvidia.com/tensorrt."""
|
880
922
|
assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
|
881
923
|
f_onnx, _ = self.export_onnx() # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016
|
882
924
|
|
@@ -912,7 +954,7 @@ class Exporter:
|
|
912
954
|
|
913
955
|
@try_export
|
914
956
|
def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
|
915
|
-
"""YOLO TensorFlow SavedModel
|
957
|
+
"""Export YOLO model to TensorFlow SavedModel format."""
|
916
958
|
cuda = torch.cuda.is_available()
|
917
959
|
try:
|
918
960
|
import tensorflow as tf # noqa
|
@@ -1002,7 +1044,7 @@ class Exporter:
|
|
1002
1044
|
|
1003
1045
|
@try_export
|
1004
1046
|
def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
|
1005
|
-
"""YOLO TensorFlow GraphDef *.pb
|
1047
|
+
"""Export YOLO model to TensorFlow GraphDef *.pb format https://github.com/leimao/Frozen-Graph-TensorFlow."""
|
1006
1048
|
import tensorflow as tf # noqa
|
1007
1049
|
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
|
1008
1050
|
|
@@ -1018,7 +1060,7 @@ class Exporter:
|
|
1018
1060
|
|
1019
1061
|
@try_export
|
1020
1062
|
def export_tflite(self, prefix=colorstr("TensorFlow Lite:")):
|
1021
|
-
"""YOLO TensorFlow Lite
|
1063
|
+
"""Export YOLO model to TensorFlow Lite format."""
|
1022
1064
|
# BUG https://github.com/ultralytics/ultralytics/issues/13436
|
1023
1065
|
import tensorflow as tf # noqa
|
1024
1066
|
|
@@ -1034,7 +1076,7 @@ class Exporter:
|
|
1034
1076
|
|
1035
1077
|
@try_export
|
1036
1078
|
def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
|
1037
|
-
"""YOLO Edge TPU
|
1079
|
+
"""Export YOLO model to Edge TPU format https://coral.ai/docs/edgetpu/models-intro/."""
|
1038
1080
|
cmd = "edgetpu_compiler --version"
|
1039
1081
|
help_url = "https://coral.ai/docs/edgetpu/compiler/"
|
1040
1082
|
assert LINUX, f"export only supported on Linux. See {help_url}"
|
@@ -1069,7 +1111,7 @@ class Exporter:
|
|
1069
1111
|
|
1070
1112
|
@try_export
|
1071
1113
|
def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
|
1072
|
-
"""YOLO TensorFlow.js
|
1114
|
+
"""Export YOLO model to TensorFlow.js format."""
|
1073
1115
|
check_requirements("tensorflowjs")
|
1074
1116
|
import tensorflow as tf
|
1075
1117
|
import tensorflowjs as tfjs # noqa
|
@@ -1102,7 +1144,7 @@ class Exporter:
|
|
1102
1144
|
|
1103
1145
|
@try_export
|
1104
1146
|
def export_rknn(self, prefix=colorstr("RKNN:")):
|
1105
|
-
"""YOLO RKNN
|
1147
|
+
"""Export YOLO model to RKNN format."""
|
1106
1148
|
LOGGER.info(f"\n{prefix} starting export with rknn-toolkit2...")
|
1107
1149
|
|
1108
1150
|
check_requirements("rknn-toolkit2")
|
@@ -1129,7 +1171,7 @@ class Exporter:
|
|
1129
1171
|
|
1130
1172
|
@try_export
|
1131
1173
|
def export_imx(self, prefix=colorstr("IMX:")):
|
1132
|
-
"""YOLO IMX
|
1174
|
+
"""Export YOLO model to IMX format."""
|
1133
1175
|
gptq = False
|
1134
1176
|
assert LINUX, (
|
1135
1177
|
"export only supported on Linux. "
|
@@ -1212,6 +1254,8 @@ class Exporter:
|
|
1212
1254
|
)
|
1213
1255
|
|
1214
1256
|
class NMSWrapper(torch.nn.Module):
|
1257
|
+
"""Wrap PyTorch Module with multiclass_nms layer from sony_custom_layers."""
|
1258
|
+
|
1215
1259
|
def __init__(
|
1216
1260
|
self,
|
1217
1261
|
model: torch.nn.Module,
|
@@ -1220,13 +1264,13 @@ class Exporter:
|
|
1220
1264
|
max_detections: int = 300,
|
1221
1265
|
):
|
1222
1266
|
"""
|
1223
|
-
|
1267
|
+
Initialize NMSWrapper with PyTorch Module and NMS parameters.
|
1224
1268
|
|
1225
1269
|
Args:
|
1226
|
-
model (nn.Module): Model instance.
|
1270
|
+
model (torch.nn.Module): Model instance.
|
1227
1271
|
score_threshold (float): Score threshold for non-maximum suppression.
|
1228
1272
|
iou_threshold (float): Intersection over union threshold for non-maximum suppression.
|
1229
|
-
max_detections (
|
1273
|
+
max_detections (int): The number of detections to return.
|
1230
1274
|
"""
|
1231
1275
|
super().__init__()
|
1232
1276
|
self.model = model
|
@@ -1235,6 +1279,7 @@ class Exporter:
|
|
1235
1279
|
self.max_detections = max_detections
|
1236
1280
|
|
1237
1281
|
def forward(self, images):
|
1282
|
+
"""Forward pass with model inference and NMS post-processing."""
|
1238
1283
|
# model inference
|
1239
1284
|
outputs = self.model(images)
|
1240
1285
|
|
@@ -1289,7 +1334,7 @@ class Exporter:
|
|
1289
1334
|
zf.writestr("metadata.json", json.dumps(self.metadata, indent=2))
|
1290
1335
|
|
1291
1336
|
def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
|
1292
|
-
"""
|
1337
|
+
"""Create CoreML pipeline with NMS for YOLO detection models."""
|
1293
1338
|
import coremltools as ct # noqa
|
1294
1339
|
|
1295
1340
|
LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
|
@@ -1395,7 +1440,7 @@ class Exporter:
|
|
1395
1440
|
return model
|
1396
1441
|
|
1397
1442
|
def add_callback(self, event: str, callback):
|
1398
|
-
"""
|
1443
|
+
"""Append the given callback to the specified event."""
|
1399
1444
|
self.callbacks[event].append(callback)
|
1400
1445
|
|
1401
1446
|
def run_callbacks(self, event: str):
|
@@ -1408,7 +1453,13 @@ class IOSDetectModel(torch.nn.Module):
|
|
1408
1453
|
"""Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
|
1409
1454
|
|
1410
1455
|
def __init__(self, model, im):
|
1411
|
-
"""
|
1456
|
+
"""
|
1457
|
+
Initialize the IOSDetectModel class with a YOLO model and example image.
|
1458
|
+
|
1459
|
+
Args:
|
1460
|
+
model (torch.nn.Module): The YOLO model to wrap.
|
1461
|
+
im (torch.Tensor): Example input tensor with shape (B, C, H, W).
|
1462
|
+
"""
|
1412
1463
|
super().__init__()
|
1413
1464
|
_, _, h, w = im.shape # batch, channel, height, width
|
1414
1465
|
self.model = model
|
@@ -1432,7 +1483,7 @@ class NMSModel(torch.nn.Module):
|
|
1432
1483
|
Initialize the NMSModel.
|
1433
1484
|
|
1434
1485
|
Args:
|
1435
|
-
model (torch.nn.
|
1486
|
+
model (torch.nn.Module): The model to wrap with NMS postprocessing.
|
1436
1487
|
args (Namespace): The export arguments.
|
1437
1488
|
"""
|
1438
1489
|
super().__init__()
|
@@ -1443,13 +1494,14 @@ class NMSModel(torch.nn.Module):
|
|
1443
1494
|
|
1444
1495
|
def forward(self, x):
|
1445
1496
|
"""
|
1446
|
-
|
1497
|
+
Perform inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
|
1447
1498
|
|
1448
1499
|
Args:
|
1449
1500
|
x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).
|
1450
1501
|
|
1451
1502
|
Returns:
|
1452
|
-
(torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the
|
1503
|
+
(torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the
|
1504
|
+
number of detections after NMS.
|
1453
1505
|
"""
|
1454
1506
|
from functools import partial
|
1455
1507
|
|
ultralytics/engine/model.py
CHANGED
@@ -48,25 +48,25 @@ class Model(torch.nn.Module):
|
|
48
48
|
|
49
49
|
Methods:
|
50
50
|
__call__: Alias for the predict method, enabling the model instance to be callable.
|
51
|
-
_new:
|
52
|
-
_load:
|
53
|
-
_check_is_pytorch_model:
|
54
|
-
reset_weights:
|
55
|
-
load:
|
56
|
-
save:
|
57
|
-
info:
|
58
|
-
fuse:
|
59
|
-
predict:
|
60
|
-
track:
|
61
|
-
val:
|
62
|
-
benchmark:
|
63
|
-
export:
|
64
|
-
train:
|
65
|
-
tune:
|
66
|
-
_apply:
|
67
|
-
add_callback:
|
68
|
-
clear_callback:
|
69
|
-
reset_callbacks:
|
51
|
+
_new: Initialize a new model based on a configuration file.
|
52
|
+
_load: Load a model from a checkpoint file.
|
53
|
+
_check_is_pytorch_model: Ensure that the model is a PyTorch model.
|
54
|
+
reset_weights: Reset the model's weights to their initial state.
|
55
|
+
load: Load model weights from a specified file.
|
56
|
+
save: Save the current state of the model to a file.
|
57
|
+
info: Log or return information about the model.
|
58
|
+
fuse: Fuse Conv2d and BatchNorm2d layers for optimized inference.
|
59
|
+
predict: Perform object detection predictions.
|
60
|
+
track: Perform object tracking.
|
61
|
+
val: Validate the model on a dataset.
|
62
|
+
benchmark: Benchmark the model on various export formats.
|
63
|
+
export: Export the model to different formats.
|
64
|
+
train: Train the model on a dataset.
|
65
|
+
tune: Perform hyperparameter tuning.
|
66
|
+
_apply: Apply a function to the model's tensors.
|
67
|
+
add_callback: Add a callback function for an event.
|
68
|
+
clear_callback: Clear all callbacks for an event.
|
69
|
+
reset_callbacks: Reset all callbacks to their default functions.
|
70
70
|
|
71
71
|
Examples:
|
72
72
|
>>> from ultralytics import YOLO
|
@@ -94,7 +94,7 @@ class Model(torch.nn.Module):
|
|
94
94
|
Args:
|
95
95
|
model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a
|
96
96
|
model name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.
|
97
|
-
task (str
|
97
|
+
task (str, optional): The specific task for the model. If None, it will be inferred from the config.
|
98
98
|
verbose (bool): If True, enables verbose output during the model's initialization and subsequent
|
99
99
|
operations.
|
100
100
|
|
@@ -242,9 +242,9 @@ class Model(torch.nn.Module):
|
|
242
242
|
|
243
243
|
Args:
|
244
244
|
cfg (str): Path to the model configuration file in YAML format.
|
245
|
-
task (str
|
246
|
-
model (torch.nn.Module
|
247
|
-
a new one.
|
245
|
+
task (str, optional): The specific task for the model. If None, it will be inferred from the config.
|
246
|
+
model (torch.nn.Module, optional): A custom model instance. If provided, it will be used instead of
|
247
|
+
creating a new one.
|
248
248
|
verbose (bool): If True, displays model information during loading.
|
249
249
|
|
250
250
|
Raises:
|
@@ -276,7 +276,7 @@ class Model(torch.nn.Module):
|
|
276
276
|
|
277
277
|
Args:
|
278
278
|
weights (str): Path to the model weights file to be loaded.
|
279
|
-
task (str
|
279
|
+
task (str, optional): The task associated with the model. If None, it will be inferred from the model.
|
280
280
|
|
281
281
|
Raises:
|
282
282
|
FileNotFoundError: If the specified weights file does not exist or is inaccessible.
|
@@ -367,7 +367,7 @@ class Model(torch.nn.Module):
|
|
367
367
|
name and shape and transfers them to the model.
|
368
368
|
|
369
369
|
Args:
|
370
|
-
weights (
|
370
|
+
weights (str | Path): Path to the weights file or a weights object.
|
371
371
|
|
372
372
|
Returns:
|
373
373
|
(Model): The instance of the class with loaded weights.
|
@@ -501,7 +501,7 @@ class Model(torch.nn.Module):
|
|
501
501
|
**kwargs: Any,
|
502
502
|
) -> List[Results]:
|
503
503
|
"""
|
504
|
-
|
504
|
+
Perform predictions on the given image source using the YOLO model.
|
505
505
|
|
506
506
|
This method facilitates the prediction process, allowing various configurations through keyword arguments.
|
507
507
|
It supports predictions with custom predictors or the default predictor method. The method handles different
|
@@ -512,7 +512,7 @@ class Model(torch.nn.Module):
|
|
512
512
|
of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL
|
513
513
|
images, numpy arrays, and torch tensors.
|
514
514
|
stream (bool): If True, treats the input source as a continuous stream for predictions.
|
515
|
-
predictor (BasePredictor
|
515
|
+
predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.
|
516
516
|
If None, the method uses a default predictor.
|
517
517
|
**kwargs (Any): Additional keyword arguments for configuring the prediction process.
|
518
518
|
|
@@ -562,14 +562,14 @@ class Model(torch.nn.Module):
|
|
562
562
|
**kwargs: Any,
|
563
563
|
) -> List[Results]:
|
564
564
|
"""
|
565
|
-
|
565
|
+
Conduct object tracking on the specified input source using the registered trackers.
|
566
566
|
|
567
567
|
This method performs object tracking using the model's predictors and optionally registered trackers. It handles
|
568
568
|
various input sources such as file paths or video streams, and supports customization through keyword arguments.
|
569
569
|
The method registers trackers if not already present and can persist them between calls.
|
570
570
|
|
571
571
|
Args:
|
572
|
-
source (
|
572
|
+
source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor, optional): Input source for object
|
573
573
|
tracking. Can be a file path, URL, or video stream.
|
574
574
|
stream (bool): If True, treats the input source as a continuous video stream.
|
575
575
|
persist (bool): If True, persists trackers between different calls to this method.
|
@@ -611,8 +611,8 @@ class Model(torch.nn.Module):
|
|
611
611
|
configurations, method-specific defaults, and user-provided arguments to configure the validation process.
|
612
612
|
|
613
613
|
Args:
|
614
|
-
validator (ultralytics.engine.validator.BaseValidator
|
615
|
-
validating the model.
|
614
|
+
validator (ultralytics.engine.validator.BaseValidator, optional): An instance of a custom validator class
|
615
|
+
for validating the model.
|
616
616
|
**kwargs (Any): Arbitrary keyword arguments for customizing the validation process.
|
617
617
|
|
618
618
|
Returns:
|
@@ -634,10 +634,7 @@ class Model(torch.nn.Module):
|
|
634
634
|
self.metrics = validator.metrics
|
635
635
|
return validator.metrics
|
636
636
|
|
637
|
-
def benchmark(
|
638
|
-
self,
|
639
|
-
**kwargs: Any,
|
640
|
-
):
|
637
|
+
def benchmark(self, data=None, format="", verbose=False, **kwargs: Any):
|
641
638
|
"""
|
642
639
|
Benchmark the model across various export formats to evaluate performance.
|
643
640
|
|
@@ -647,14 +644,14 @@ class Model(torch.nn.Module):
|
|
647
644
|
defaults, and any additional user-provided keyword arguments.
|
648
645
|
|
649
646
|
Args:
|
647
|
+
data (str): Path to the dataset for benchmarking.
|
648
|
+
verbose (bool): Whether to print detailed benchmark information.
|
649
|
+
format (str): Export format name for specific benchmarking.
|
650
650
|
**kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. Common options include:
|
651
|
-
- data (str): Path to the dataset for benchmarking.
|
652
651
|
- imgsz (int | List[int]): Image size for benchmarking.
|
653
652
|
- half (bool): Whether to use half-precision (FP16) mode.
|
654
653
|
- int8 (bool): Whether to use int8 precision mode.
|
655
654
|
- device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
|
656
|
-
- verbose (bool): Whether to print detailed benchmark information.
|
657
|
-
- format (str): Export format name for specific benchmarking.
|
658
655
|
|
659
656
|
Returns:
|
660
657
|
(dict): A dictionary containing the results of the benchmarking process, including metrics for
|
@@ -671,17 +668,21 @@ class Model(torch.nn.Module):
|
|
671
668
|
self._check_is_pytorch_model()
|
672
669
|
from ultralytics.utils.benchmarks import benchmark
|
673
670
|
|
671
|
+
from .exporter import export_formats
|
672
|
+
|
674
673
|
custom = {"verbose": False} # method defaults
|
675
674
|
args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
|
675
|
+
fmts = export_formats()
|
676
|
+
export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get(format, []))
|
677
|
+
export_kwargs = {k: v for k, v in args.items() if k in export_args - set(["batch"])}
|
676
678
|
return benchmark(
|
677
679
|
model=self,
|
678
|
-
data=
|
680
|
+
data=data, # if no 'data' argument passed set data=None for default datasets
|
679
681
|
imgsz=args["imgsz"],
|
680
|
-
half=args["half"],
|
681
|
-
int8=args["int8"],
|
682
682
|
device=args["device"],
|
683
|
-
verbose=
|
684
|
-
format=
|
683
|
+
verbose=verbose,
|
684
|
+
format=format,
|
685
|
+
**export_kwargs,
|
685
686
|
)
|
686
687
|
|
687
688
|
def export(
|
@@ -738,7 +739,7 @@ class Model(torch.nn.Module):
|
|
738
739
|
**kwargs: Any,
|
739
740
|
):
|
740
741
|
"""
|
741
|
-
|
742
|
+
Train the model using the specified dataset and training configuration.
|
742
743
|
|
743
744
|
This method facilitates model training with a range of customizable settings. It supports training with a
|
744
745
|
custom trainer or the default training approach. The method handles scenarios such as resuming training
|
@@ -749,7 +750,7 @@ class Model(torch.nn.Module):
|
|
749
750
|
configurations, method-specific defaults, and user-provided arguments to configure the training process.
|
750
751
|
|
751
752
|
Args:
|
752
|
-
trainer (BaseTrainer
|
753
|
+
trainer (BaseTrainer, optional): Custom trainer instance for model training. If None, uses default.
|
753
754
|
**kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
|
754
755
|
data (str): Path to dataset configuration file.
|
755
756
|
epochs (int): Number of training epochs.
|
@@ -810,7 +811,7 @@ class Model(torch.nn.Module):
|
|
810
811
|
**kwargs: Any,
|
811
812
|
):
|
812
813
|
"""
|
813
|
-
|
814
|
+
Conduct hyperparameter tuning for the model, with an option to use Ray Tune.
|
814
815
|
|
815
816
|
This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method.
|
816
817
|
When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module.
|
@@ -881,7 +882,7 @@ class Model(torch.nn.Module):
|
|
881
882
|
@property
|
882
883
|
def names(self) -> Dict[int, str]:
|
883
884
|
"""
|
884
|
-
|
885
|
+
Retrieve the class names associated with the loaded model.
|
885
886
|
|
886
887
|
This property returns the class names if they are defined in the model. It checks the class names for validity
|
887
888
|
using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
|
@@ -935,7 +936,7 @@ class Model(torch.nn.Module):
|
|
935
936
|
@property
|
936
937
|
def transforms(self):
|
937
938
|
"""
|
938
|
-
|
939
|
+
Retrieve the transformations applied to the input data of the loaded model.
|
939
940
|
|
940
941
|
This property returns the transformations if they are defined in the model. The transforms
|
941
942
|
typically include preprocessing steps like resizing, normalization, and data augmentation
|
@@ -982,7 +983,7 @@ class Model(torch.nn.Module):
|
|
982
983
|
|
983
984
|
def clear_callback(self, event: str) -> None:
|
984
985
|
"""
|
985
|
-
|
986
|
+
Clear all callback functions registered for a specified event.
|
986
987
|
|
987
988
|
This method removes all custom and default callback functions associated with the given event.
|
988
989
|
It resets the callback list for the specified event to an empty list, effectively removing all
|
@@ -1062,7 +1063,7 @@ class Model(torch.nn.Module):
|
|
1062
1063
|
|
1063
1064
|
def _smart_load(self, key: str):
|
1064
1065
|
"""
|
1065
|
-
Intelligently
|
1066
|
+
Intelligently load the appropriate module based on the model task.
|
1066
1067
|
|
1067
1068
|
This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
|
1068
1069
|
based on the current task of the model and the provided key. It uses the task_map dictionary to determine
|
@@ -1092,7 +1093,7 @@ class Model(torch.nn.Module):
|
|
1092
1093
|
@property
|
1093
1094
|
def task_map(self) -> dict:
|
1094
1095
|
"""
|
1095
|
-
|
1096
|
+
Provide a mapping from model tasks to corresponding classes for different modes.
|
1096
1097
|
|
1097
1098
|
This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
|
1098
1099
|
to a nested dictionary. The nested dictionary contains mappings for different operational modes
|