ultralytics 8.3.143__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +11 -11
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +39 -39
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +187 -157
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +1 -1
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +1 -1
  93. ultralytics/solutions/instance_segmentation.py +6 -3
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +15 -7
  96. ultralytics/solutions/object_cropper.py +3 -2
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +184 -75
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +42 -28
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
  143. ultralytics-8.3.144.dist-info/RECORD +272 -0
  144. ultralytics-8.3.143.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
@@ -222,11 +222,53 @@ def arange_patch(args):
222
222
 
223
223
  class Exporter:
224
224
  """
225
- A class for exporting a model.
225
+ A class for exporting YOLO models to various formats.
226
+
227
+ This class provides functionality to export YOLO models to different formats including ONNX, TensorRT, CoreML,
228
+ TensorFlow, and others. It handles format validation, device selection, model preparation, and the actual export
229
+ process for each supported format.
226
230
 
227
231
  Attributes:
228
- args (SimpleNamespace): Configuration for the exporter.
229
- callbacks (list, optional): List of callback functions.
232
+ args (SimpleNamespace): Configuration arguments for the exporter.
233
+ callbacks (dict): Dictionary of callback functions for different export events.
234
+ im (torch.Tensor): Input tensor for model inference during export.
235
+ model (torch.nn.Module): The YOLO model to be exported.
236
+ file (Path): Path to the model file being exported.
237
+ output_shape (tuple): Shape of the model output tensor(s).
238
+ pretty_name (str): Formatted model name for display purposes.
239
+ metadata (dict): Model metadata including description, author, version, etc.
240
+ device (torch.device): Device on which the model is loaded.
241
+ imgsz (tuple): Input image size for the model.
242
+
243
+ Methods:
244
+ __call__: Main export method that handles the export process.
245
+ get_int8_calibration_dataloader: Build dataloader for INT8 calibration.
246
+ export_torchscript: Export model to TorchScript format.
247
+ export_onnx: Export model to ONNX format.
248
+ export_openvino: Export model to OpenVINO format.
249
+ export_paddle: Export model to PaddlePaddle format.
250
+ export_mnn: Export model to MNN format.
251
+ export_ncnn: Export model to NCNN format.
252
+ export_coreml: Export model to CoreML format.
253
+ export_engine: Export model to TensorRT format.
254
+ export_saved_model: Export model to TensorFlow SavedModel format.
255
+ export_pb: Export model to TensorFlow GraphDef format.
256
+ export_tflite: Export model to TensorFlow Lite format.
257
+ export_edgetpu: Export model to Edge TPU format.
258
+ export_tfjs: Export model to TensorFlow.js format.
259
+ export_rknn: Export model to RKNN format.
260
+ export_imx: Export model to IMX format.
261
+
262
+ Examples:
263
+ Export a YOLOv8 model to ONNX format
264
+ >>> from ultralytics.engine.exporter import Exporter
265
+ >>> exporter = Exporter()
266
+ >>> exporter(model="yolov8n.pt") # exports to yolov8n.onnx
267
+
268
+ Export with specific arguments
269
+ >>> args = {"format": "onnx", "dynamic": True, "half": True}
270
+ >>> exporter = Exporter(overrides=args)
271
+ >>> exporter(model="yolov8n.pt")
230
272
  """
231
273
 
232
274
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
@@ -536,7 +578,7 @@ class Exporter:
536
578
 
537
579
  @try_export
538
580
  def export_torchscript(self, prefix=colorstr("TorchScript:")):
539
- """YOLO TorchScript model export."""
581
+ """Export YOLO model to TorchScript format."""
540
582
  LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
541
583
  f = self.file.with_suffix(".torchscript")
542
584
 
@@ -553,7 +595,7 @@ class Exporter:
553
595
 
554
596
  @try_export
555
597
  def export_onnx(self, prefix=colorstr("ONNX:")):
556
- """YOLO ONNX export."""
598
+ """Export YOLO model to ONNX format."""
557
599
  requirements = ["onnx>=1.12.0,<1.18.0"]
558
600
  if self.args.simplify:
559
601
  requirements += ["onnxslim>=0.1.53", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
@@ -612,7 +654,7 @@ class Exporter:
612
654
 
613
655
  @try_export
614
656
  def export_openvino(self, prefix=colorstr("OpenVINO:")):
615
- """YOLO OpenVINO export."""
657
+ """Export YOLO model to OpenVINO format."""
616
658
  if MACOS:
617
659
  msg = "OpenVINO error in macOS>=15.4 https://github.com/openvinotoolkit/openvino/issues/30023"
618
660
  check_version(MACOS_VERSION, "<15.4", name="macOS ", hard=True, msg=msg)
@@ -689,7 +731,7 @@ class Exporter:
689
731
 
690
732
  @try_export
691
733
  def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
692
- """YOLO Paddle export."""
734
+ """Export YOLO model to PaddlePaddle format."""
693
735
  assert not IS_JETSON, "Jetson Paddle exports not supported yet"
694
736
  check_requirements(("paddlepaddle-gpu" if torch.cuda.is_available() else "paddlepaddle>=3.0.0", "x2paddle"))
695
737
  import x2paddle # noqa
@@ -704,7 +746,7 @@ class Exporter:
704
746
 
705
747
  @try_export
706
748
  def export_mnn(self, prefix=colorstr("MNN:")):
707
- """YOLO MNN export using MNN https://github.com/alibaba/MNN."""
749
+ """Export YOLO model to MNN format using MNN https://github.com/alibaba/MNN."""
708
750
  f_onnx, _ = self.export_onnx() # get onnx model first
709
751
 
710
752
  check_requirements("MNN>=2.9.6")
@@ -729,7 +771,7 @@ class Exporter:
729
771
 
730
772
  @try_export
731
773
  def export_ncnn(self, prefix=colorstr("NCNN:")):
732
- """YOLO NCNN export using PNNX https://github.com/pnnx/pnnx."""
774
+ """Export YOLO model to NCNN format using PNNX https://github.com/pnnx/pnnx."""
733
775
  check_requirements("ncnn")
734
776
  import ncnn # noqa
735
777
 
@@ -797,7 +839,7 @@ class Exporter:
797
839
 
798
840
  @try_export
799
841
  def export_coreml(self, prefix=colorstr("CoreML:")):
800
- """YOLO CoreML export."""
842
+ """Export YOLO model to CoreML format."""
801
843
  mlmodel = self.args.format.lower() == "mlmodel" # legacy *.mlmodel export format requested
802
844
  check_requirements("coremltools>=8.0")
803
845
  import coremltools as ct # noqa
@@ -876,7 +918,7 @@ class Exporter:
876
918
 
877
919
  @try_export
878
920
  def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
879
- """YOLO TensorRT export https://developer.nvidia.com/tensorrt."""
921
+ """Export YOLO model to TensorRT format https://developer.nvidia.com/tensorrt."""
880
922
  assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
881
923
  f_onnx, _ = self.export_onnx() # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016
882
924
 
@@ -912,7 +954,7 @@ class Exporter:
912
954
 
913
955
  @try_export
914
956
  def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
915
- """YOLO TensorFlow SavedModel export."""
957
+ """Export YOLO model to TensorFlow SavedModel format."""
916
958
  cuda = torch.cuda.is_available()
917
959
  try:
918
960
  import tensorflow as tf # noqa
@@ -1002,7 +1044,7 @@ class Exporter:
1002
1044
 
1003
1045
  @try_export
1004
1046
  def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
1005
- """YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen-Graph-TensorFlow."""
1047
+ """Export YOLO model to TensorFlow GraphDef *.pb format https://github.com/leimao/Frozen-Graph-TensorFlow."""
1006
1048
  import tensorflow as tf # noqa
1007
1049
  from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
1008
1050
 
@@ -1018,7 +1060,7 @@ class Exporter:
1018
1060
 
1019
1061
  @try_export
1020
1062
  def export_tflite(self, prefix=colorstr("TensorFlow Lite:")):
1021
- """YOLO TensorFlow Lite export."""
1063
+ """Export YOLO model to TensorFlow Lite format."""
1022
1064
  # BUG https://github.com/ultralytics/ultralytics/issues/13436
1023
1065
  import tensorflow as tf # noqa
1024
1066
 
@@ -1034,7 +1076,7 @@ class Exporter:
1034
1076
 
1035
1077
  @try_export
1036
1078
  def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
1037
- """YOLO Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
1079
+ """Export YOLO model to Edge TPU format https://coral.ai/docs/edgetpu/models-intro/."""
1038
1080
  cmd = "edgetpu_compiler --version"
1039
1081
  help_url = "https://coral.ai/docs/edgetpu/compiler/"
1040
1082
  assert LINUX, f"export only supported on Linux. See {help_url}"
@@ -1069,7 +1111,7 @@ class Exporter:
1069
1111
 
1070
1112
  @try_export
1071
1113
  def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
1072
- """YOLO TensorFlow.js export."""
1114
+ """Export YOLO model to TensorFlow.js format."""
1073
1115
  check_requirements("tensorflowjs")
1074
1116
  import tensorflow as tf
1075
1117
  import tensorflowjs as tfjs # noqa
@@ -1102,7 +1144,7 @@ class Exporter:
1102
1144
 
1103
1145
  @try_export
1104
1146
  def export_rknn(self, prefix=colorstr("RKNN:")):
1105
- """YOLO RKNN model export."""
1147
+ """Export YOLO model to RKNN format."""
1106
1148
  LOGGER.info(f"\n{prefix} starting export with rknn-toolkit2...")
1107
1149
 
1108
1150
  check_requirements("rknn-toolkit2")
@@ -1129,7 +1171,7 @@ class Exporter:
1129
1171
 
1130
1172
  @try_export
1131
1173
  def export_imx(self, prefix=colorstr("IMX:")):
1132
- """YOLO IMX export."""
1174
+ """Export YOLO model to IMX format."""
1133
1175
  gptq = False
1134
1176
  assert LINUX, (
1135
1177
  "export only supported on Linux. "
@@ -1212,6 +1254,8 @@ class Exporter:
1212
1254
  )
1213
1255
 
1214
1256
  class NMSWrapper(torch.nn.Module):
1257
+ """Wrap PyTorch Module with multiclass_nms layer from sony_custom_layers."""
1258
+
1215
1259
  def __init__(
1216
1260
  self,
1217
1261
  model: torch.nn.Module,
@@ -1220,13 +1264,13 @@ class Exporter:
1220
1264
  max_detections: int = 300,
1221
1265
  ):
1222
1266
  """
1223
- Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.
1267
+ Initialize NMSWrapper with PyTorch Module and NMS parameters.
1224
1268
 
1225
1269
  Args:
1226
- model (nn.Module): Model instance.
1270
+ model (torch.nn.Module): Model instance.
1227
1271
  score_threshold (float): Score threshold for non-maximum suppression.
1228
1272
  iou_threshold (float): Intersection over union threshold for non-maximum suppression.
1229
- max_detections (float): The number of detections to return.
1273
+ max_detections (int): The number of detections to return.
1230
1274
  """
1231
1275
  super().__init__()
1232
1276
  self.model = model
@@ -1235,6 +1279,7 @@ class Exporter:
1235
1279
  self.max_detections = max_detections
1236
1280
 
1237
1281
  def forward(self, images):
1282
+ """Forward pass with model inference and NMS post-processing."""
1238
1283
  # model inference
1239
1284
  outputs = self.model(images)
1240
1285
 
@@ -1289,7 +1334,7 @@ class Exporter:
1289
1334
  zf.writestr("metadata.json", json.dumps(self.metadata, indent=2))
1290
1335
 
1291
1336
  def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
1292
- """YOLO CoreML pipeline."""
1337
+ """Create CoreML pipeline with NMS for YOLO detection models."""
1293
1338
  import coremltools as ct # noqa
1294
1339
 
1295
1340
  LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
@@ -1395,7 +1440,7 @@ class Exporter:
1395
1440
  return model
1396
1441
 
1397
1442
  def add_callback(self, event: str, callback):
1398
- """Appends the given callback."""
1443
+ """Append the given callback to the specified event."""
1399
1444
  self.callbacks[event].append(callback)
1400
1445
 
1401
1446
  def run_callbacks(self, event: str):
@@ -1408,7 +1453,13 @@ class IOSDetectModel(torch.nn.Module):
1408
1453
  """Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
1409
1454
 
1410
1455
  def __init__(self, model, im):
1411
- """Initialize the IOSDetectModel class with a YOLO model and example image."""
1456
+ """
1457
+ Initialize the IOSDetectModel class with a YOLO model and example image.
1458
+
1459
+ Args:
1460
+ model (torch.nn.Module): The YOLO model to wrap.
1461
+ im (torch.Tensor): Example input tensor with shape (B, C, H, W).
1462
+ """
1412
1463
  super().__init__()
1413
1464
  _, _, h, w = im.shape # batch, channel, height, width
1414
1465
  self.model = model
@@ -1432,7 +1483,7 @@ class NMSModel(torch.nn.Module):
1432
1483
  Initialize the NMSModel.
1433
1484
 
1434
1485
  Args:
1435
- model (torch.nn.module): The model to wrap with NMS postprocessing.
1486
+ model (torch.nn.Module): The model to wrap with NMS postprocessing.
1436
1487
  args (Namespace): The export arguments.
1437
1488
  """
1438
1489
  super().__init__()
@@ -1443,13 +1494,14 @@ class NMSModel(torch.nn.Module):
1443
1494
 
1444
1495
  def forward(self, x):
1445
1496
  """
1446
- Performs inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
1497
+ Perform inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
1447
1498
 
1448
1499
  Args:
1449
1500
  x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).
1450
1501
 
1451
1502
  Returns:
1452
- (torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the number of detections after NMS.
1503
+ (torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the
1504
+ number of detections after NMS.
1453
1505
  """
1454
1506
  from functools import partial
1455
1507
 
@@ -48,25 +48,25 @@ class Model(torch.nn.Module):
48
48
 
49
49
  Methods:
50
50
  __call__: Alias for the predict method, enabling the model instance to be callable.
51
- _new: Initializes a new model based on a configuration file.
52
- _load: Loads a model from a checkpoint file.
53
- _check_is_pytorch_model: Ensures that the model is a PyTorch model.
54
- reset_weights: Resets the model's weights to their initial state.
55
- load: Loads model weights from a specified file.
56
- save: Saves the current state of the model to a file.
57
- info: Logs or returns information about the model.
58
- fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference.
59
- predict: Performs object detection predictions.
60
- track: Performs object tracking.
61
- val: Validates the model on a dataset.
62
- benchmark: Benchmarks the model on various export formats.
63
- export: Exports the model to different formats.
64
- train: Trains the model on a dataset.
65
- tune: Performs hyperparameter tuning.
66
- _apply: Applies a function to the model's tensors.
67
- add_callback: Adds a callback function for an event.
68
- clear_callback: Clears all callbacks for an event.
69
- reset_callbacks: Resets all callbacks to their default functions.
51
+ _new: Initialize a new model based on a configuration file.
52
+ _load: Load a model from a checkpoint file.
53
+ _check_is_pytorch_model: Ensure that the model is a PyTorch model.
54
+ reset_weights: Reset the model's weights to their initial state.
55
+ load: Load model weights from a specified file.
56
+ save: Save the current state of the model to a file.
57
+ info: Log or return information about the model.
58
+ fuse: Fuse Conv2d and BatchNorm2d layers for optimized inference.
59
+ predict: Perform object detection predictions.
60
+ track: Perform object tracking.
61
+ val: Validate the model on a dataset.
62
+ benchmark: Benchmark the model on various export formats.
63
+ export: Export the model to different formats.
64
+ train: Train the model on a dataset.
65
+ tune: Perform hyperparameter tuning.
66
+ _apply: Apply a function to the model's tensors.
67
+ add_callback: Add a callback function for an event.
68
+ clear_callback: Clear all callbacks for an event.
69
+ reset_callbacks: Reset all callbacks to their default functions.
70
70
 
71
71
  Examples:
72
72
  >>> from ultralytics import YOLO
@@ -94,7 +94,7 @@ class Model(torch.nn.Module):
94
94
  Args:
95
95
  model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a
96
96
  model name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.
97
- task (str | None): The task type associated with the YOLO model, specifying its application domain.
97
+ task (str, optional): The specific task for the model. If None, it will be inferred from the config.
98
98
  verbose (bool): If True, enables verbose output during the model's initialization and subsequent
99
99
  operations.
100
100
 
@@ -242,9 +242,9 @@ class Model(torch.nn.Module):
242
242
 
243
243
  Args:
244
244
  cfg (str): Path to the model configuration file in YAML format.
245
- task (str | None): The specific task for the model. If None, it will be inferred from the config.
246
- model (torch.nn.Module | None): A custom model instance. If provided, it will be used instead of creating
247
- a new one.
245
+ task (str, optional): The specific task for the model. If None, it will be inferred from the config.
246
+ model (torch.nn.Module, optional): A custom model instance. If provided, it will be used instead of
247
+ creating a new one.
248
248
  verbose (bool): If True, displays model information during loading.
249
249
 
250
250
  Raises:
@@ -276,7 +276,7 @@ class Model(torch.nn.Module):
276
276
 
277
277
  Args:
278
278
  weights (str): Path to the model weights file to be loaded.
279
- task (str | None): The task associated with the model. If None, it will be inferred from the model.
279
+ task (str, optional): The task associated with the model. If None, it will be inferred from the model.
280
280
 
281
281
  Raises:
282
282
  FileNotFoundError: If the specified weights file does not exist or is inaccessible.
@@ -367,7 +367,7 @@ class Model(torch.nn.Module):
367
367
  name and shape and transfers them to the model.
368
368
 
369
369
  Args:
370
- weights (Union[str, Path]): Path to the weights file or a weights object.
370
+ weights (str | Path): Path to the weights file or a weights object.
371
371
 
372
372
  Returns:
373
373
  (Model): The instance of the class with loaded weights.
@@ -501,7 +501,7 @@ class Model(torch.nn.Module):
501
501
  **kwargs: Any,
502
502
  ) -> List[Results]:
503
503
  """
504
- Performs predictions on the given image source using the YOLO model.
504
+ Perform predictions on the given image source using the YOLO model.
505
505
 
506
506
  This method facilitates the prediction process, allowing various configurations through keyword arguments.
507
507
  It supports predictions with custom predictors or the default predictor method. The method handles different
@@ -512,7 +512,7 @@ class Model(torch.nn.Module):
512
512
  of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL
513
513
  images, numpy arrays, and torch tensors.
514
514
  stream (bool): If True, treats the input source as a continuous stream for predictions.
515
- predictor (BasePredictor | None): An instance of a custom predictor class for making predictions.
515
+ predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.
516
516
  If None, the method uses a default predictor.
517
517
  **kwargs (Any): Additional keyword arguments for configuring the prediction process.
518
518
 
@@ -562,14 +562,14 @@ class Model(torch.nn.Module):
562
562
  **kwargs: Any,
563
563
  ) -> List[Results]:
564
564
  """
565
- Conducts object tracking on the specified input source using the registered trackers.
565
+ Conduct object tracking on the specified input source using the registered trackers.
566
566
 
567
567
  This method performs object tracking using the model's predictors and optionally registered trackers. It handles
568
568
  various input sources such as file paths or video streams, and supports customization through keyword arguments.
569
569
  The method registers trackers if not already present and can persist them between calls.
570
570
 
571
571
  Args:
572
- source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
572
+ source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor, optional): Input source for object
573
573
  tracking. Can be a file path, URL, or video stream.
574
574
  stream (bool): If True, treats the input source as a continuous video stream.
575
575
  persist (bool): If True, persists trackers between different calls to this method.
@@ -611,8 +611,8 @@ class Model(torch.nn.Module):
611
611
  configurations, method-specific defaults, and user-provided arguments to configure the validation process.
612
612
 
613
613
  Args:
614
- validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for
615
- validating the model.
614
+ validator (ultralytics.engine.validator.BaseValidator, optional): An instance of a custom validator class
615
+ for validating the model.
616
616
  **kwargs (Any): Arbitrary keyword arguments for customizing the validation process.
617
617
 
618
618
  Returns:
@@ -738,7 +738,7 @@ class Model(torch.nn.Module):
738
738
  **kwargs: Any,
739
739
  ):
740
740
  """
741
- Trains the model using the specified dataset and training configuration.
741
+ Train the model using the specified dataset and training configuration.
742
742
 
743
743
  This method facilitates model training with a range of customizable settings. It supports training with a
744
744
  custom trainer or the default training approach. The method handles scenarios such as resuming training
@@ -749,7 +749,7 @@ class Model(torch.nn.Module):
749
749
  configurations, method-specific defaults, and user-provided arguments to configure the training process.
750
750
 
751
751
  Args:
752
- trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default.
752
+ trainer (BaseTrainer, optional): Custom trainer instance for model training. If None, uses default.
753
753
  **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
754
754
  data (str): Path to dataset configuration file.
755
755
  epochs (int): Number of training epochs.
@@ -810,7 +810,7 @@ class Model(torch.nn.Module):
810
810
  **kwargs: Any,
811
811
  ):
812
812
  """
813
- Conducts hyperparameter tuning for the model, with an option to use Ray Tune.
813
+ Conduct hyperparameter tuning for the model, with an option to use Ray Tune.
814
814
 
815
815
  This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method.
816
816
  When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module.
@@ -881,7 +881,7 @@ class Model(torch.nn.Module):
881
881
  @property
882
882
  def names(self) -> Dict[int, str]:
883
883
  """
884
- Retrieves the class names associated with the loaded model.
884
+ Retrieve the class names associated with the loaded model.
885
885
 
886
886
  This property returns the class names if they are defined in the model. It checks the class names for validity
887
887
  using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
@@ -935,7 +935,7 @@ class Model(torch.nn.Module):
935
935
  @property
936
936
  def transforms(self):
937
937
  """
938
- Retrieves the transformations applied to the input data of the loaded model.
938
+ Retrieve the transformations applied to the input data of the loaded model.
939
939
 
940
940
  This property returns the transformations if they are defined in the model. The transforms
941
941
  typically include preprocessing steps like resizing, normalization, and data augmentation
@@ -982,7 +982,7 @@ class Model(torch.nn.Module):
982
982
 
983
983
  def clear_callback(self, event: str) -> None:
984
984
  """
985
- Clears all callback functions registered for a specified event.
985
+ Clear all callback functions registered for a specified event.
986
986
 
987
987
  This method removes all custom and default callback functions associated with the given event.
988
988
  It resets the callback list for the specified event to an empty list, effectively removing all
@@ -1062,7 +1062,7 @@ class Model(torch.nn.Module):
1062
1062
 
1063
1063
  def _smart_load(self, key: str):
1064
1064
  """
1065
- Intelligently loads the appropriate module based on the model task.
1065
+ Intelligently load the appropriate module based on the model task.
1066
1066
 
1067
1067
  This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
1068
1068
  based on the current task of the model and the provided key. It uses the task_map dictionary to determine
@@ -1092,7 +1092,7 @@ class Model(torch.nn.Module):
1092
1092
  @property
1093
1093
  def task_map(self) -> dict:
1094
1094
  """
1095
- Provides a mapping from model tasks to corresponding classes for different modes.
1095
+ Provide a mapping from model tasks to corresponding classes for different modes.
1096
1096
 
1097
1097
  This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
1098
1098
  to a nested dictionary. The nested dictionary contains mappings for different operational modes
@@ -36,6 +36,7 @@ import platform
36
36
  import re
37
37
  import threading
38
38
  from pathlib import Path
39
+ from typing import Any, Dict, List, Optional, Union
39
40
 
40
41
  import cv2
41
42
  import numpy as np
@@ -78,15 +79,15 @@ class BasePredictor:
78
79
  data (dict): Data configuration.
79
80
  device (torch.device): Device used for prediction.
80
81
  dataset (Dataset): Dataset used for prediction.
81
- vid_writer (dict): Dictionary of {save_path: video_writer} for saving video output.
82
- plotted_img (numpy.ndarray): Last plotted image.
82
+ vid_writer (Dict[str, cv2.VideoWriter]): Dictionary of {save_path: video_writer} for saving video output.
83
+ plotted_img (np.ndarray): Last plotted image.
83
84
  source_type (SimpleNamespace): Type of input source.
84
85
  seen (int): Number of images processed.
85
- windows (list): List of window names for visualization.
86
+ windows (List[str]): List of window names for visualization.
86
87
  batch (tuple): Current batch data.
87
- results (list): Current batch results.
88
+ results (List[Any]): Current batch results.
88
89
  transforms (callable): Image transforms for classification.
89
- callbacks (dict): Callback functions for different events.
90
+ callbacks (Dict[str, List[callable]]): Callback functions for different events.
90
91
  txt_path (Path): Path to save text results.
91
92
  _lock (threading.Lock): Lock for thread-safe inference.
92
93
 
@@ -105,14 +106,19 @@ class BasePredictor:
105
106
  add_callback: Register a new callback function.
106
107
  """
107
108
 
108
- def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
109
+ def __init__(
110
+ self,
111
+ cfg=DEFAULT_CFG,
112
+ overrides: Optional[Dict[str, Any]] = None,
113
+ _callbacks: Optional[Dict[str, List[callable]]] = None,
114
+ ):
109
115
  """
110
116
  Initialize the BasePredictor class.
111
117
 
112
118
  Args:
113
119
  cfg (str | dict): Path to a configuration file or a configuration dictionary.
114
- overrides (dict | None): Configuration overrides.
115
- _callbacks (dict | None): Dictionary of callback functions.
120
+ overrides (dict, optional): Configuration overrides.
121
+ _callbacks (dict, optional): Dictionary of callback functions.
116
122
  """
117
123
  self.args = get_cfg(cfg, overrides)
118
124
  self.save_dir = get_save_dir(self.args)
@@ -141,12 +147,15 @@ class BasePredictor:
141
147
  self._lock = threading.Lock() # for automatic thread-safe inference
142
148
  callbacks.add_integration_callbacks(self)
143
149
 
144
- def preprocess(self, im):
150
+ def preprocess(self, im: Union[torch.Tensor, List[np.ndarray]]) -> torch.Tensor:
145
151
  """
146
- Prepares input image before inference.
152
+ Prepare input image before inference.
147
153
 
148
154
  Args:
149
- im (torch.Tensor | List(np.ndarray)): Images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
155
+ im (torch.Tensor | List[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.
156
+
157
+ Returns:
158
+ (torch.Tensor): Preprocessed image tensor of shape (N, 3, H, W).
150
159
  """
151
160
  not_tensor = not isinstance(im, torch.Tensor)
152
161
  if not_tensor:
@@ -163,7 +172,7 @@ class BasePredictor:
163
172
  im /= 255 # 0 - 255 to 0.0 - 1.0
164
173
  return im
165
174
 
166
- def inference(self, im, *args, **kwargs):
175
+ def inference(self, im: torch.Tensor, *args, **kwargs):
167
176
  """Run inference on a given image using the specified model and arguments."""
168
177
  visualize = (
169
178
  increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
@@ -172,15 +181,15 @@ class BasePredictor:
172
181
  )
173
182
  return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
174
183
 
175
- def pre_transform(self, im):
184
+ def pre_transform(self, im: List[np.ndarray]) -> List[np.ndarray]:
176
185
  """
177
186
  Pre-transform input image before inference.
178
187
 
179
188
  Args:
180
- im (List[np.ndarray]): Images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
189
+ im (List[np.ndarray]): List of images with shape [(H, W, 3) x N].
181
190
 
182
191
  Returns:
183
- (List[np.ndarray]): A list of transformed images.
192
+ (List[np.ndarray]): List of transformed images.
184
193
  """
185
194
  same_shapes = len({x.shape for x in im}) == 1
186
195
  letterbox = LetterBox(
@@ -196,14 +205,14 @@ class BasePredictor:
196
205
  """Post-process predictions for an image and return them."""
197
206
  return preds
198
207
 
199
- def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
208
+ def __call__(self, source=None, model=None, stream: bool = False, *args, **kwargs):
200
209
  """
201
210
  Perform inference on an image or stream.
202
211
 
203
212
  Args:
204
- source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
213
+ source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):
205
214
  Source for inference.
206
- model (str | Path | torch.nn.Module | None): Model for inference.
215
+ model (str | Path | torch.nn.Module, optional): Model for inference.
207
216
  stream (bool): Whether to stream the inference results. If True, returns a generator.
208
217
  *args (Any): Additional arguments for the inference method.
209
218
  **kwargs (Any): Additional keyword arguments for the inference method.
@@ -226,9 +235,9 @@ class BasePredictor:
226
235
  generator without storing results.
227
236
 
228
237
  Args:
229
- source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
238
+ source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):
230
239
  Source for inference.
231
- model (str | Path | torch.nn.Module | None): Model for inference.
240
+ model (str | Path | torch.nn.Module, optional): Model for inference.
232
241
 
233
242
  Note:
234
243
  Do not modify this function or remove the generator. The generator ensures that no outputs are
@@ -270,9 +279,9 @@ class BasePredictor:
270
279
  Stream real-time inference on camera feed and save results to file.
271
280
 
272
281
  Args:
273
- source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
282
+ source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor, optional):
274
283
  Source for inference.
275
- model (str | Path | torch.nn.Module | None): Model for inference.
284
+ model (str | Path | torch.nn.Module, optional): Model for inference.
276
285
  *args (Any): Additional arguments for the inference method.
277
286
  **kwargs (Any): Additional keyword arguments for the inference method.
278
287
 
@@ -365,12 +374,12 @@ class BasePredictor:
365
374
  LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
366
375
  self.run_callbacks("on_predict_end")
367
376
 
368
- def setup_model(self, model, verbose=True):
377
+ def setup_model(self, model, verbose: bool = True):
369
378
  """
370
379
  Initialize YOLO model with given parameters and set it to evaluation mode.
371
380
 
372
381
  Args:
373
- model (str | Path | torch.nn.Module | None): Model to load or use.
382
+ model (str | Path | torch.nn.Module, optional): Model to load or use.
374
383
  verbose (bool): Whether to print verbose output.
375
384
  """
376
385
  self.model = AutoBackend(
@@ -390,7 +399,7 @@ class BasePredictor:
390
399
  self.args.imgsz = self.model.imgsz # reuse imgsz from export metadata
391
400
  self.model.eval()
392
401
 
393
- def write_results(self, i, p, im, s):
402
+ def write_results(self, i: int, p: Path, im: torch.Tensor, s: List[str]) -> str:
394
403
  """
395
404
  Write inference results to a file or directory.
396
405
 
@@ -441,7 +450,7 @@ class BasePredictor:
441
450
 
442
451
  return string
443
452
 
444
- def save_predicted_images(self, save_path="", frame=0):
453
+ def save_predicted_images(self, save_path: str = "", frame: int = 0):
445
454
  """
446
455
  Save video predictions as mp4 or images as jpg at specified path.
447
456
 
@@ -475,7 +484,7 @@ class BasePredictor:
475
484
  else:
476
485
  cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im) # save to JPG for best support
477
486
 
478
- def show(self, p=""):
487
+ def show(self, p: str = ""):
479
488
  """Display an image in a window."""
480
489
  im = self.plotted_img
481
490
  if platform.system() == "Linux" and p not in self.windows:
@@ -490,6 +499,6 @@ class BasePredictor:
490
499
  for callback in self.callbacks.get(event, []):
491
500
  callback(self)
492
501
 
493
- def add_callback(self, event: str, func):
502
+ def add_callback(self, event: str, func: callable):
494
503
  """Add a callback function for a specific event."""
495
504
  self.callbacks[event].append(func)