ultralytics 8.2.81__py3-none-any.whl → 8.2.83__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic.

Files changed (97)
  1. tests/test_solutions.py +0 -4
  2. ultralytics/__init__.py +1 -1
  3. ultralytics/cfg/__init__.py +21 -21
  4. ultralytics/data/annotator.py +1 -1
  5. ultralytics/data/augment.py +58 -58
  6. ultralytics/data/base.py +3 -3
  7. ultralytics/data/converter.py +7 -8
  8. ultralytics/data/explorer/explorer.py +7 -23
  9. ultralytics/data/loaders.py +2 -2
  10. ultralytics/data/split_dota.py +11 -3
  11. ultralytics/data/utils.py +6 -10
  12. ultralytics/engine/exporter.py +2 -4
  13. ultralytics/engine/model.py +47 -47
  14. ultralytics/engine/predictor.py +1 -1
  15. ultralytics/engine/results.py +28 -28
  16. ultralytics/engine/trainer.py +11 -8
  17. ultralytics/engine/tuner.py +7 -8
  18. ultralytics/engine/validator.py +3 -5
  19. ultralytics/hub/__init__.py +5 -5
  20. ultralytics/hub/auth.py +6 -2
  21. ultralytics/hub/session.py +3 -5
  22. ultralytics/models/fastsam/model.py +13 -10
  23. ultralytics/models/fastsam/predict.py +2 -2
  24. ultralytics/models/fastsam/utils.py +0 -1
  25. ultralytics/models/nas/model.py +4 -4
  26. ultralytics/models/nas/predict.py +1 -2
  27. ultralytics/models/nas/val.py +1 -1
  28. ultralytics/models/rtdetr/predict.py +1 -1
  29. ultralytics/models/rtdetr/train.py +1 -1
  30. ultralytics/models/rtdetr/val.py +1 -1
  31. ultralytics/models/sam/model.py +11 -11
  32. ultralytics/models/sam/modules/decoders.py +7 -4
  33. ultralytics/models/sam/modules/sam.py +9 -1
  34. ultralytics/models/sam/modules/tiny_encoder.py +1 -1
  35. ultralytics/models/sam/modules/transformer.py +0 -2
  36. ultralytics/models/sam/modules/utils.py +1 -1
  37. ultralytics/models/sam/predict.py +10 -10
  38. ultralytics/models/utils/loss.py +29 -17
  39. ultralytics/models/utils/ops.py +1 -5
  40. ultralytics/models/yolo/classify/predict.py +1 -1
  41. ultralytics/models/yolo/classify/train.py +1 -1
  42. ultralytics/models/yolo/classify/val.py +1 -1
  43. ultralytics/models/yolo/detect/predict.py +1 -1
  44. ultralytics/models/yolo/detect/train.py +1 -1
  45. ultralytics/models/yolo/detect/val.py +1 -1
  46. ultralytics/models/yolo/model.py +6 -2
  47. ultralytics/models/yolo/obb/predict.py +1 -1
  48. ultralytics/models/yolo/obb/train.py +1 -1
  49. ultralytics/models/yolo/obb/val.py +2 -2
  50. ultralytics/models/yolo/pose/predict.py +1 -1
  51. ultralytics/models/yolo/pose/train.py +1 -1
  52. ultralytics/models/yolo/pose/val.py +1 -1
  53. ultralytics/models/yolo/segment/predict.py +1 -1
  54. ultralytics/models/yolo/segment/train.py +1 -1
  55. ultralytics/models/yolo/segment/val.py +1 -1
  56. ultralytics/models/yolo/world/train.py +1 -1
  57. ultralytics/nn/autobackend.py +2 -2
  58. ultralytics/nn/modules/__init__.py +2 -2
  59. ultralytics/nn/modules/block.py +8 -20
  60. ultralytics/nn/modules/conv.py +1 -3
  61. ultralytics/nn/modules/head.py +16 -31
  62. ultralytics/nn/modules/transformer.py +0 -1
  63. ultralytics/nn/modules/utils.py +0 -1
  64. ultralytics/nn/tasks.py +11 -9
  65. ultralytics/solutions/__init__.py +1 -0
  66. ultralytics/solutions/ai_gym.py +0 -2
  67. ultralytics/solutions/analytics.py +1 -6
  68. ultralytics/solutions/heatmap.py +0 -1
  69. ultralytics/solutions/object_counter.py +0 -2
  70. ultralytics/solutions/queue_management.py +0 -2
  71. ultralytics/trackers/basetrack.py +1 -1
  72. ultralytics/trackers/byte_tracker.py +2 -2
  73. ultralytics/trackers/utils/gmc.py +5 -5
  74. ultralytics/trackers/utils/kalman_filter.py +1 -1
  75. ultralytics/trackers/utils/matching.py +1 -5
  76. ultralytics/utils/__init__.py +137 -24
  77. ultralytics/utils/autobatch.py +7 -4
  78. ultralytics/utils/benchmarks.py +6 -14
  79. ultralytics/utils/callbacks/base.py +0 -1
  80. ultralytics/utils/callbacks/comet.py +0 -1
  81. ultralytics/utils/callbacks/tensorboard.py +0 -1
  82. ultralytics/utils/checks.py +15 -18
  83. ultralytics/utils/downloads.py +6 -7
  84. ultralytics/utils/files.py +3 -4
  85. ultralytics/utils/instance.py +17 -7
  86. ultralytics/utils/metrics.py +16 -16
  87. ultralytics/utils/ops.py +8 -8
  88. ultralytics/utils/plotting.py +25 -35
  89. ultralytics/utils/tal.py +27 -18
  90. ultralytics/utils/torch_utils.py +12 -13
  91. ultralytics/utils/tuner.py +2 -3
  92. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/METADATA +4 -3
  93. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/RECORD +97 -97
  94. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/WHEEL +1 -1
  95. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/LICENSE +0 -0
  96. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/entry_points.txt +0 -0
  97. {ultralytics-8.2.81.dist-info → ultralytics-8.2.83.dist-info}/top_level.txt +0 -0

ultralytics/utils/autobatch.py CHANGED
@@ -16,13 +16,17 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
 
     Args:
         model (torch.nn.Module): YOLO model to check batch size for.
-        imgsz (int): Image size used for training.
-        amp (bool): If True, use automatic mixed precision (AMP) for training.
+        imgsz (int, optional): Image size used for training.
+        amp (bool, optional): Use automatic mixed precision if True.
+        batch (float, optional): Fraction of GPU memory to use. If -1, use default.
 
     Returns:
         (int): Optimal batch size computed using the autobatch() function.
-    """
 
+    Note:
+        If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.
+        Otherwise, a default fraction of 0.6 is used.
+    """
     with autocast(enabled=amp):
         return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)
 
@@ -40,7 +44,6 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
     Returns:
         (int): The optimal batch size.
     """
-
     # Check device
     prefix = colorstr("AutoBatch: ")
     LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")

ultralytics/utils/benchmarks.py CHANGED
@@ -71,7 +71,7 @@ def benchmark(
         ```python
         from ultralytics.utils.benchmarks import benchmark
 
-        benchmark(model='yolov8n.pt', imgsz=640)
+        benchmark(model="yolov8n.pt", imgsz=640)
         ```
     """
     import pandas as pd  # scope for faster 'import ultralytics'
@@ -97,20 +97,17 @@ def benchmark(
                 assert MACOS or LINUX, "CoreML and TF.js export only supported on macOS and Linux"
                 assert not IS_RASPBERRYPI, "CoreML and TF.js export not supported on Raspberry Pi"
                 assert not IS_JETSON, "CoreML and TF.js export not supported on NVIDIA Jetson"
-                assert not is_end2end, "End-to-end models not supported by CoreML and TF.js yet"
             if i in {3, 5}:  # CoreML and OpenVINO
                 assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
             if i in {6, 7, 8}:  # TF SavedModel, TF GraphDef, and TFLite
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
             if i in {9, 10}:  # TF EdgeTPU and TF.js
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
-                assert not is_end2end, "End-to-end models not supported by TF EdgeTPU and TF.js yet"
             if i in {11}:  # Paddle
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
                 assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
             if i in {12}:  # NCNN
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
-                assert not is_end2end, "End-to-end models not supported by NCNN yet"
             if "cpu" in device.type:
                 assert cpu, "inference not supported on CPU"
             if "cuda" in device.type:
@@ -130,6 +127,8 @@ def benchmark(
             assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
             assert i not in {9, 10}, "inference not supported"  # Edge TPU and TF.js are unsupported
             assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13"  # CoreML
+            if i in {12}:
+                assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
             exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
 
             # Validate
@@ -182,7 +181,6 @@ class RF100Benchmark:
         Args:
             api_key (str): The API key.
         """
-
         check_requirements("roboflow")
         from roboflow import Roboflow
 
@@ -195,7 +193,6 @@ class RF100Benchmark:
         Args:
             ds_link_txt (str): Path to dataset_links file.
         """
-
         (shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100")
         os.chdir("rf-100")
         os.mkdir("ultralytics-benchmarks")
@@ -225,7 +222,6 @@ class RF100Benchmark:
         Args:
             path (str): YAML file path.
         """
-
         with open(path, "r") as file:
             yaml_data = yaml.safe_load(file)
             yaml_data["train"] = "train/images"
@@ -302,7 +298,7 @@ class ProfileModels:
        ```python
        from ultralytics.utils.benchmarks import ProfileModels
 
-        ProfileModels(['yolov8n.yaml', 'yolov8s.yaml'], imgsz=640).profile()
+        ProfileModels(["yolov8n.yaml", "yolov8s.yaml"], imgsz=640).profile()
        ```
    """
 
@@ -393,9 +389,7 @@ class ProfileModels:
         return [Path(file) for file in sorted(files)]
 
     def get_onnx_model_info(self, onnx_file: str):
-        """Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model
-        file.
-        """
+        """Extracts metadata from an ONNX model file including parameters, GFLOPs, and input shape."""
         return 0.0, 0.0, 0.0, 0.0  # return (num_layers, num_params, num_gradients, num_flops)
 
     @staticmethod
@@ -440,9 +434,7 @@ class ProfileModels:
         return np.mean(run_times), np.std(run_times)
 
     def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
-        """Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
-        times.
-        """
+        """Profiles an ONNX model, measuring average inference time and standard deviation across multiple runs."""
         check_requirements("onnxruntime")
         import onnxruntime as ort
 

ultralytics/utils/callbacks/base.py CHANGED
@@ -192,7 +192,6 @@ def add_integration_callbacks(instance):
        instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
            of callback lists.
    """
-
    # Load HUB callbacks
    from .hub import callbacks as hub_cb
 

ultralytics/utils/callbacks/comet.py CHANGED
@@ -114,7 +114,6 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
 
    This function rescales the bounding box labels to the original image shape.
    """
-
    resized_image_height, resized_image_width = resized_image_shape
 
    # Convert normalized xywh format predictions to xyxy in resized scale format

ultralytics/utils/callbacks/tensorboard.py CHANGED
@@ -34,7 +34,6 @@ def _log_scalars(scalars, step=0):
 
 def _log_tensorboard_graph(trainer):
     """Log model graph to TensorBoard."""
-
     # Input image
     imgsz = trainer.args.imgsz
     imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz

ultralytics/utils/checks.py CHANGED
@@ -23,6 +23,7 @@ from ultralytics.utils import (
     ASSETS,
     AUTOINSTALL,
     IS_COLAB,
+    IS_GIT_DIR,
     IS_JUPYTER,
     IS_KAGGLE,
     IS_PIP_PACKAGE,
@@ -61,10 +62,9 @@ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
         ```python
         from ultralytics.utils.checks import parse_requirements
 
-        parse_requirements(package='ultralytics')
+        parse_requirements(package="ultralytics")
         ```
     """
-
     if package:
         requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
     else:
@@ -196,16 +196,16 @@ def check_version(
     Example:
         ```python
         # Check if current version is exactly 22.04
-        check_version(current='22.04', required='==22.04')
+        check_version(current="22.04", required="==22.04")
 
         # Check if current version is greater than or equal to 22.04
-        check_version(current='22.10', required='22.04')  # assumes '>=' inequality if none passed
+        check_version(current="22.10", required="22.04")  # assumes '>=' inequality if none passed
 
         # Check if current version is less than or equal to 22.04
-        check_version(current='22.04', required='<=22.04')
+        check_version(current="22.04", required="<=22.04")
 
         # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
-        check_version(current='21.10', required='>20.04,<22.04')
+        check_version(current="21.10", required=">20.04,<22.04")
         ```
     """
     if not current:  # if current is '' or None
@@ -256,7 +256,7 @@ def check_latest_pypi_version(package_name="ultralytics"):
     """
     Returns the latest version of a PyPI package without downloading or installing it.
 
-    Parameters:
+    Args:
         package_name (str): The name of the package to find the latest version for.
 
     Returns:
@@ -352,16 +352,15 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
         from ultralytics.utils.checks import check_requirements
 
         # Check a requirements.txt file
-        check_requirements('path/to/requirements.txt')
+        check_requirements("path/to/requirements.txt")
 
         # Check a single package
-        check_requirements('ultralytics>=8.0.0')
+        check_requirements("ultralytics>=8.0.0")
 
         # Check multiple packages
-        check_requirements(['numpy', 'ultralytics>=8.0.0'])
+        check_requirements(["numpy", "ultralytics>=8.0.0"])
         ```
     """
-
     prefix = colorstr("red", "bold", "requirements:")
     check_python()  # check python version
     check_torchvision()  # check torch-torchvision compatibility
@@ -421,7 +420,6 @@ def check_torchvision():
     The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
     Torchvision versions.
     """
-
     # Compatibility table
     compatibility_table = {
         "2.3": ["0.18"],
@@ -582,10 +580,9 @@ def check_yolo(verbose=True, device=""):
 
 def collect_system_info():
     """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA."""
-
     import psutil
 
-    from ultralytics.utils import ENVIRONMENT, IS_GIT_DIR
+    from ultralytics.utils import ENVIRONMENT  # scope to avoid circular import
     from ultralytics.utils.torch_utils import get_cpu_info
 
     ram_info = psutil.virtual_memory().total / (1024**3)  # Convert bytes to GB
@@ -622,9 +619,9 @@ def collect_system_info():
 
 def check_amp(model):
     """
-    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
-    fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
-    be disabled during training.
+    Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks fail, it means
+    there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will be disabled
+    during training.
 
     Args:
         model (nn.Module): A YOLOv8 model instance.
@@ -634,7 +631,7 @@ def check_amp(model):
         from ultralytics import YOLO
         from ultralytics.utils.checks import check_amp
 
-        model = YOLO('yolov8n.pt').model.cuda()
+        model = YOLO("yolov8n.pt").model.cuda()
         check_amp(model)
         ```
 
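
Since check_version() returns a bool, the reformatted docstring examples above evaluate roughly as follows (a small sketch restating the operator semantics described in that docstring):

```python
from ultralytics.utils.checks import check_version

print(check_version(current="22.10", required="22.04"))          # True, '>=' is assumed when no operator is given
print(check_version(current="22.04", required="==22.04"))        # True, exact match
print(check_version(current="21.10", required=">20.04,<22.04"))  # True, inside both bounds
print(check_version(current="20.04", required=">20.04,<22.04"))  # False, the '>' bound is exclusive
```
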

ultralytics/utils/downloads.py CHANGED
@@ -75,7 +75,7 @@ def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
         ```python
         from ultralytics.utils.downloads import delete_dsstore
 
-        delete_dsstore('path/to/dir')
+        delete_dsstore("path/to/dir")
         ```
 
     Note:
@@ -107,7 +107,7 @@ def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), p
         ```python
         from ultralytics.utils.downloads import zip_directory
 
-        file = zip_directory('path/to/dir')
+        file = zip_directory("path/to/dir")
         ```
     """
     from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile
@@ -153,7 +153,7 @@ def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=Fals
         ```python
         from ultralytics.utils.downloads import unzip_file
 
-        dir = unzip_file('path/to/file.zip')
+        dir = unzip_file("path/to/file.zip")
         ```
     """
     from zipfile import BadZipFile, ZipFile, is_zipfile
@@ -392,10 +392,9 @@ def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
 
     Example:
         ```python
-        tag, assets = get_github_assets(repo='ultralytics/assets', version='latest')
+        tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
         ```
     """
-
     if version != "latest":
         version = f"tags/{version}"  # i.e. tags/v6.2
     url = f"https://api.github.com/repos/{repo}/releases/{version}"
@@ -425,7 +424,7 @@ def attempt_download_asset(file, repo="ultralytics/assets", release="v8.2.0", **
 
     Example:
         ```python
-        file_path = attempt_download_asset('yolov8n.pt', repo='ultralytics/assets', release='latest')
+        file_path = attempt_download_asset("yolov8n.pt", repo="ultralytics/assets", release="latest")
         ```
     """
     from ultralytics.utils import SETTINGS  # scoped for circular import
@@ -480,7 +479,7 @@ def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=
 
     Example:
         ```python
-        download('https://ultralytics.com/assets/example.zip', dir='path/to/dir', unzip=True)
+        download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True)
         ```
     """
     dir = Path(dir)
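
A short round-trip sketch for the zip/unzip helpers whose docstring examples were reformatted above; "path/to/dir" is the same placeholder those docstrings use, and the return values (archive path from zip_directory(), extraction directory from unzip_file()) follow what the docstring examples imply:

```python
from ultralytics.utils.downloads import unzip_file, zip_directory

# "path/to/dir" is a placeholder; point it at a real directory before running.
archive = zip_directory("path/to/dir", compress=True)  # returns the path of the created .zip
extracted = unzip_file(archive, exist_ok=True)         # returns the directory it extracted into
print(archive, extracted)
```
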

ultralytics/utils/files.py CHANGED
@@ -28,13 +28,13 @@ class WorkingDirectory(contextlib.ContextDecorator):
     Examples:
         Using as a context manager:
         >>> with WorkingDirectory('/path/to/new/dir'):
-        >>> # Perform operations in the new directory
+        >>>     # Perform operations in the new directory
         >>>     pass
 
         Using as a decorator:
         >>> @WorkingDirectory('/path/to/new/dir')
         >>> def some_function():
-        >>> # Perform operations in the new directory
+        >>>     # Perform operations in the new directory
         >>>     pass
     """
 
@@ -69,9 +69,8 @@ def spaces_in_path(path):
         Use the context manager to handle paths with spaces:
         >>> from ultralytics.utils.files import spaces_in_path
         >>> with spaces_in_path('/path/with spaces') as new_path:
-        >>> # Your code here
+        >>>     # Your code here
     """
-
     # If path has spaces, replace them with underscores
     if " " in str(path):
         string = isinstance(path, str)  # input type

ultralytics/utils/instance.py CHANGED
@@ -96,8 +96,11 @@ class Bboxes:
 
     def mul(self, scale):
         """
+        Multiply bounding box coordinates by scale factor(s).
+
         Args:
-            scale (tuple | list | int): the scale for four coords.
+            scale (int | tuple | list): Scale factor(s) for four coordinates.
+                If int, the same scale is applied to all coordinates.
         """
         if isinstance(scale, Number):
             scale = to_4tuple(scale)
@@ -110,8 +113,11 @@ class Bboxes:
 
     def add(self, offset):
         """
+        Add offset to bounding box coordinates.
+
         Args:
-            offset (tuple | list | int): the offset for four coords.
+            offset (int | tuple | list): Offset(s) for four coordinates.
+                If int, the same offset is applied to all coordinates.
         """
         if isinstance(offset, Number):
             offset = to_4tuple(offset)
@@ -199,7 +205,7 @@ class Instances:
         instances = Instances(
             bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
             segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
-            keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]])
+            keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]),
         )
         ```
 
@@ -210,10 +216,14 @@ class Instances:
 
     def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
         """
+        Initialize the object with bounding boxes, segments, and keypoints.
+
         Args:
-            bboxes (ndarray): bboxes with shape [N, 4].
-            segments (list | ndarray): segments.
-            keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
+            bboxes (np.ndarray): Bounding boxes, shape [N, 4].
+            segments (list | np.ndarray, optional): Segmentation masks. Defaults to None.
+            keypoints (np.ndarray, optional): Keypoints, shape [N, 17, 3] and format (x, y, visible). Defaults to None.
+            bbox_format (str, optional): Format of bboxes. Defaults to "xywh".
+            normalized (bool, optional): Whether the coordinates are normalized. Defaults to True.
         """
         self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
         self.keypoints = keypoints
@@ -230,7 +240,7 @@ class Instances:
         return self._bboxes.areas()
 
     def scale(self, scale_w, scale_h, bbox_only=False):
-        """This might be similar with denormalize func but without normalized sign."""
+        """Similar to denormalize func but without normalized sign."""
         self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
         if bbox_only:
             return
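
The clarified mul() and add() docstrings above distinguish a scalar argument (applied to all four coordinates) from a per-coordinate tuple. A minimal sketch with illustrative numbers only:

```python
import numpy as np

from ultralytics.utils.instance import Bboxes

boxes = Bboxes(bboxes=np.array([[10.0, 10.0, 30.0, 30.0]]), format="xyxy")
boxes.mul(2.0)           # scalar: the same scale is applied to all four coordinates
boxes.add((5, 0, 5, 0))  # tuple: one offset per coordinate
print(boxes.bboxes)      # [[25. 20. 65. 60.]]
```
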

ultralytics/utils/metrics.py CHANGED
@@ -30,7 +30,6 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
     Returns:
         (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
     """
-
     # Get the coordinates of bounding boxes
     b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
     b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
@@ -53,7 +52,7 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
 def box_iou(box1, box2, eps=1e-7):
     """
     Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
-    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py.
 
     Args:
         box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
@@ -63,7 +62,6 @@ def box_iou(box1, box2, eps=1e-7):
     Returns:
         (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
     """
-
     # NOTE: Need .float() to get accurate iou values
     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
     (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)
@@ -90,7 +88,6 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
     Returns:
         (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
     """
-
     # Get the coordinates of bounding boxes
     if xywh:  # transform from xywh to xyxy
         (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
@@ -181,7 +178,7 @@ def _get_covariance_matrix(boxes):
         boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
 
     Returns:
-        (torch.Tensor): Covariance metrixs corresponding to original rotated bounding boxes.
+        (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.
     """
     # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
     gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
@@ -195,15 +192,22 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
 
 def probiou(obb1, obb2, CIoU=False, eps=1e-7):
     """
-    Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
+    Calculate probabilistic IoU between oriented bounding boxes.
+
+    Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
 
     Args:
-        obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
-        obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format.
-        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+        obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
+        obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
+        CIoU (bool, optional): If True, calculate CIoU. Defaults to False.
+        eps (float, optional): Small value to avoid division by zero. Defaults to 1e-7.
 
     Returns:
-        (torch.Tensor): A tensor of shape (N, ) representing obb similarities.
+        (torch.Tensor): OBB similarities, shape (N,).
+
+    Note:
+        OBB format: [center_x, center_y, width, height, rotation_angle].
+        If CIoU is True, returns CIoU instead of IoU.
     """
     x1, y1 = obb1[..., :2].split(1, dim=-1)
     x2, y2 = obb2[..., :2].split(1, dim=-1)
@@ -507,7 +511,6 @@ def compute_ap(recall, precision):
         (np.ndarray): Precision envelope curve.
         (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
     """
-
     # Append sentinel values to beginning and end
     mrec = np.concatenate(([0.0], recall, [1.0]))
     mpre = np.concatenate(([1.0], precision, [0.0]))
@@ -560,7 +563,6 @@ def ap_per_class(
         x (np.ndarray): X-axis values for the curves. Shape: (1000,).
         prec_values: Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
     """
-
     # Sort by objectness
     i = np.argsort(-conf)
     tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
@@ -792,8 +794,8 @@ class Metric(SimpleClass):
 
 class DetMetrics(SimpleClass):
     """
-    This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
-    (mAP) of an object detection model.
+    Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an
+    object detection model.
 
     Args:
         save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
@@ -942,7 +944,6 @@ class SegmentMetrics(SimpleClass):
             pred_cls (list): List of predicted classes.
             target_cls (list): List of target classes.
         """
-
         results_mask = ap_per_class(
             tp_m,
             conf,
@@ -1084,7 +1085,6 @@ class PoseMetrics(SegmentMetrics):
             pred_cls (list): List of predicted classes.
            target_cls (list): List of target classes.
        """
-
        results_pose = ap_per_class(
            tp_p,
            conf,
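
The expanded probiou() docstring above spells out the xywhr input format and the element-wise output; a minimal sketch with made-up boxes:

```python
import torch

from ultralytics.utils.metrics import probiou

obb1 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.3], [50.0, 50.0, 20.0, 10.0, 0.3]])   # ground truth, xywhr
obb2 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.3], [400.0, 400.0, 30.0, 15.0, 1.2]])  # predictions, xywhr
sim = probiou(obb1, obb2)  # element-wise similarity: ~1 for the identical pair, ~0 for the distant pair
print(sim)
```
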
ultralytics/utils/ops.py CHANGED
@@ -141,14 +141,15 @@ def make_divisible(x, divisor):
 
 def nms_rotated(boxes, scores, threshold=0.45):
     """
-    NMS for obbs, powered by probiou and fast-nms.
+    NMS for oriented bounding boxes using probiou and fast-nms.
 
     Args:
-        boxes (torch.Tensor): (N, 5), xywhr.
-        scores (torch.Tensor): (N, ).
-        threshold (float): IoU threshold.
+        boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
+        scores (torch.Tensor): Confidence scores, shape (N,).
+        threshold (float, optional): IoU threshold. Defaults to 0.45.
 
     Returns:
+        (torch.Tensor): Indices of boxes to keep after NMS.
     """
     if len(boxes) == 0:
         return np.empty((0,), dtype=np.int8)
@@ -597,7 +598,7 @@ def ltwh2xyxy(x):
 
 def segments2boxes(segments):
     """
-    It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
+    It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
 
     Args:
         segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
@@ -667,7 +668,6 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
         (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
             are the height and width of the input image. The mask is applied to the bounding boxes.
     """
-
     c, mh, mw = protos.shape  # CHW
     ih, iw = shape
     masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)  # CHW
@@ -785,7 +785,7 @@ def regularize_rboxes(rboxes):
 
 def masks2segments(masks, strategy="largest"):
     """
-    It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
+    It takes a list of masks(n,h,w) and returns a list of segments(n,xy).
 
     Args:
         masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
@@ -823,7 +823,7 @@ def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
 
 def clean_str(s):
     """
-    Cleans a string by replacing special characters with underscore _
+    Cleans a string by replacing special characters with '_' character.
 
     Args:
         s (str): a string needing special characters replaced
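
The clarified nms_rotated() docstring above now states the return value (indices of the boxes kept). A minimal sketch with made-up xywhr boxes:

```python
import torch

from ultralytics.utils.ops import nms_rotated

boxes = torch.tensor(
    [
        [50.0, 50.0, 20.0, 10.0, 0.00],  # heavily overlaps the next box
        [51.0, 50.0, 20.0, 10.0, 0.05],
        [200.0, 200.0, 30.0, 15.0, 1.20],  # far away, always kept
    ]
)
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms_rotated(boxes, scores, threshold=0.45)
print(keep)  # indices of the retained boxes; the overlapping lower-score box is suppressed
```
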