ultralytics 8.3.145__py3-none-any.whl → 8.3.147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. tests/__init__.py +3 -0
  2. tests/test_cli.py +2 -7
  3. tests/test_python.py +55 -18
  4. ultralytics/__init__.py +1 -1
  5. ultralytics/cfg/__init__.py +0 -1
  6. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  7. ultralytics/data/augment.py +2 -2
  8. ultralytics/engine/model.py +4 -4
  9. ultralytics/engine/validator.py +1 -1
  10. ultralytics/models/nas/model.py +0 -8
  11. ultralytics/models/yolo/classify/val.py +5 -9
  12. ultralytics/models/yolo/detect/val.py +8 -17
  13. ultralytics/models/yolo/obb/val.py +24 -17
  14. ultralytics/models/yolo/pose/val.py +19 -14
  15. ultralytics/models/yolo/segment/val.py +52 -44
  16. ultralytics/nn/tasks.py +3 -0
  17. ultralytics/solutions/analytics.py +17 -9
  18. ultralytics/solutions/object_counter.py +2 -4
  19. ultralytics/trackers/bot_sort.py +4 -2
  20. ultralytics/utils/__init__.py +2 -3
  21. ultralytics/utils/benchmarks.py +15 -15
  22. ultralytics/utils/checks.py +10 -5
  23. ultralytics/utils/downloads.py +1 -0
  24. ultralytics/utils/metrics.py +52 -33
  25. ultralytics/utils/plotting.py +10 -7
  26. ultralytics/utils/torch_utils.py +2 -2
  27. {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/METADATA +1 -1
  28. {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/RECORD +32 -31
  29. {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/WHEEL +1 -1
  30. {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/entry_points.txt +0 -0
  31. {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/licenses/LICENSE +0 -0
  32. {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/top_level.txt +0 -0
tests/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
3
4
  from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks
4
5
 
5
6
  # Constants used in tests
@@ -10,6 +11,8 @@ SOURCES_LIST = [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]
10
11
  TMP = (ROOT / "../tests/tmp").resolve() # temp directory for test files
11
12
  CUDA_IS_AVAILABLE = checks.cuda_is_available()
12
13
  CUDA_DEVICE_COUNT = checks.cuda_device_count()
14
+ TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS]
15
+ MODELS = frozenset(list(TASK2MODEL.values()) + ["yolo11n-grayscale.pt"])
13
16
 
14
17
  __all__ = (
15
18
  "MODEL",
tests/test_cli.py CHANGED
@@ -5,15 +5,10 @@ import subprocess
5
5
  import pytest
6
6
  from PIL import Image
7
7
 
8
- from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE
9
- from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
8
+ from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODELS, TASK_MODEL_DATA
10
9
  from ultralytics.utils import ARM64, ASSETS, LINUX, WEIGHTS_DIR, checks
11
10
  from ultralytics.utils.torch_utils import TORCH_1_9
12
11
 
13
- # Constants
14
- TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS]
15
- MODELS = [WEIGHTS_DIR / TASK2MODEL[task] for task in TASKS]
16
-
17
12
 
18
13
  def run(cmd: str) -> None:
19
14
  """Execute a shell command using subprocess."""
@@ -44,7 +39,7 @@ def test_val(task: str, model: str, data: str) -> None:
44
39
  @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
45
40
  def test_predict(task: str, model: str, data: str) -> None:
46
41
  """Test YOLO prediction on provided sample assets for specified task and model."""
47
- run(f"yolo predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
42
+ run(f"yolo {task} predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
48
43
 
49
44
 
50
45
  @pytest.mark.parametrize("model", MODELS)
tests/test_python.py CHANGED
@@ -12,9 +12,9 @@ import pytest
12
12
  import torch
13
13
  from PIL import Image
14
14
 
15
- from tests import CFG, MODEL, SOURCE, SOURCES_LIST, TMP
15
+ from tests import CFG, MODEL, MODELS, SOURCE, SOURCES_LIST, TASK_MODEL_DATA, TMP
16
16
  from ultralytics import RTDETR, YOLO
17
- from ultralytics.cfg import MODELS, TASK2DATA, TASKS
17
+ from ultralytics.cfg import TASK2DATA, TASKS
18
18
  from ultralytics.data.build import load_inference_source
19
19
  from ultralytics.utils import (
20
20
  ARM64,
@@ -112,21 +112,22 @@ def test_predict_csv_single_row():
112
112
  @pytest.mark.parametrize("model_name", MODELS)
113
113
  def test_predict_img(model_name):
114
114
  """Test YOLO model predictions on various image input types and sources, including online images."""
115
+ channels = 1 if model_name == "yolo11n-grayscale.pt" else 3
115
116
  model = YOLO(WEIGHTS_DIR / model_name)
116
- im = cv2.imread(str(SOURCE)) # uint8 numpy array
117
+ im = cv2.imread(str(SOURCE), flags=cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR) # uint8 numpy array
117
118
  assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1 # PIL
118
119
  assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1 # ndarray
119
- assert len(model(torch.rand((2, 3, 32, 32)), imgsz=32)) == 2 # batch-size 2 Tensor, FP32 0.0-1.0 RGB order
120
+ assert len(model(torch.rand((2, channels, 32, 32)), imgsz=32)) == 2 # batch-size 2 Tensor, FP32 0.0-1.0 RGB order
120
121
  assert len(model(source=[im, im], save=True, save_txt=True, imgsz=32)) == 2 # batch
121
122
  assert len(list(model(source=[im, im], save=True, stream=True, imgsz=32))) == 2 # stream
122
- assert len(model(torch.zeros(320, 640, 3).numpy().astype(np.uint8), imgsz=32)) == 1 # tensor to numpy
123
+ assert len(model(torch.zeros(320, 640, channels).numpy().astype(np.uint8), imgsz=32)) == 1 # tensor to numpy
123
124
  batch = [
124
125
  str(SOURCE), # filename
125
126
  Path(SOURCE), # Path
126
127
  "https://github.com/ultralytics/assets/releases/download/v0.0.0/zidane.jpg" if ONLINE else SOURCE, # URI
127
- cv2.imread(str(SOURCE)), # OpenCV
128
+ im, # OpenCV
128
129
  Image.open(SOURCE), # PIL
129
- np.zeros((320, 640, 3), dtype=np.uint8), # numpy
130
+ np.zeros((320, 640, channels), dtype=np.uint8), # numpy
130
131
  ]
131
132
  assert len(model(batch, imgsz=32, classes=0)) == len(batch) # multiple sources in a batch
132
133
 
@@ -177,14 +178,17 @@ def test_youtube():
177
178
 
178
179
  @pytest.mark.skipif(not ONLINE, reason="environment is offline")
179
180
  @pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
180
- def test_track_stream():
181
+ @pytest.mark.parametrize("model", MODELS)
182
+ def test_track_stream(model):
181
183
  """
182
184
  Test streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.
183
185
 
184
186
  Note imgsz=160 required for tracking for higher confidence and better matches.
185
187
  """
188
+ if model == "yolo11n-cls.pt": # classification model not supported for tracking
189
+ return
186
190
  video_url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/decelera_portrait_min.mov"
187
- model = YOLO(MODEL)
191
+ model = YOLO(model)
188
192
  model.track(video_url, imgsz=160, tracker="bytetrack.yaml")
189
193
  model.track(video_url, imgsz=160, tracker="botsort.yaml", save_frames=True) # test frame saving also
190
194
 
@@ -196,15 +200,23 @@ def test_track_stream():
196
200
  model.track(video_url, imgsz=160, tracker=custom_yaml)
197
201
 
198
202
 
199
- def test_val():
203
+ @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
204
+ def test_val(task: str, model: str, data: str) -> None:
200
205
  """Test the validation mode of the YOLO model."""
201
- metrics = YOLO(MODEL).val(data="coco8.yaml", imgsz=32)
202
- metrics.to_df()
203
- metrics.to_csv()
204
- metrics.to_xml()
205
- metrics.to_html()
206
- metrics.to_json()
207
- metrics.to_sql()
206
+ for plots in [True, False]: # Test both cases i.e. plots=True and plots=False
207
+ metrics = YOLO(model).val(data=data, imgsz=32, plots=plots)
208
+ metrics.to_df()
209
+ metrics.to_csv()
210
+ metrics.to_xml()
211
+ metrics.to_html()
212
+ metrics.to_json()
213
+ metrics.to_sql()
214
+ metrics.confusion_matrix.to_df() # Tests for confusion matrix export
215
+ metrics.confusion_matrix.to_csv()
216
+ metrics.confusion_matrix.to_xml()
217
+ metrics.confusion_matrix.to_html()
218
+ metrics.confusion_matrix.to_json()
219
+ metrics.confusion_matrix.to_sql()
208
220
 
209
221
 
210
222
  def test_train_scratch():
@@ -268,7 +280,7 @@ def test_predict_callback_and_setup():
268
280
 
269
281
 
270
282
  @pytest.mark.parametrize("model", MODELS)
271
- def test_results(model):
283
+ def test_results(model: str):
272
284
  """Test YOLO model results processing and output in various formats."""
273
285
  temp_s = "https://ultralytics.com/images/boats.jpg" if model == "yolo11n-obb.pt" else SOURCE
274
286
  results = YOLO(WEIGHTS_DIR / model)([temp_s, temp_s], imgsz=160)
@@ -699,3 +711,28 @@ def test_multichannel():
699
711
  im = np.zeros((32, 32, 10), dtype=np.uint8)
700
712
  model.predict(source=im, imgsz=32, save_txt=True, save_crop=True, augment=True)
701
713
  model.export(format="onnx")
714
+
715
+
716
+ @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
717
+ def test_grayscale(task: str, model: str, data: str) -> None:
718
+ """Test YOLO model grayscale training, validation, and prediction functionality."""
719
+ if task == "classify": # not support grayscale classification yet
720
+ return
721
+ grayscale_data = Path(TMP) / f"{Path(data).stem}-grayscale.yaml"
722
+ data = YAML.load(checks.check_file(data))
723
+ data["channels"] = 1 # add additional channels key for grayscale
724
+ YAML.save(grayscale_data, data)
725
+ # remove npy files in train/val splits if exists, might be created by previous tests
726
+ for split in {"train", "val"}:
727
+ for npy_file in (Path(data["path"]) / data[split]).glob("*.npy"):
728
+ npy_file.unlink()
729
+
730
+ model = YOLO(model)
731
+ model.train(data=grayscale_data, epochs=1, imgsz=32, close_mosaic=1)
732
+ model.val(data=grayscale_data)
733
+ im = np.zeros((32, 32, 1), dtype=np.uint8)
734
+ model.predict(source=im, imgsz=32, save_txt=True, save_crop=True, augment=True)
735
+ export_model = model.export(format="onnx")
736
+
737
+ model = YOLO(export_model, task=task)
738
+ model.predict(source=im, imgsz=32)
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
- __version__ = "8.3.145"
3
+ __version__ = "8.3.147"
4
4
 
5
5
  import os
6
6
 
@@ -70,7 +70,6 @@ TASK2METRIC = {
70
70
  "pose": "metrics/mAP50-95(P)",
71
71
  "obb": "metrics/mAP50-95(B)",
72
72
  }
73
- MODELS = frozenset(TASK2MODEL[task] for task in TASKS)
74
73
 
75
74
  ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
76
75
  SOLUTIONS_HELP_MSG = f"""
@@ -0,0 +1,103 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ # COCO8-Grayscale dataset (first 8 images from COCO train2017) by Ultralytics
4
+ # Documentation: https://docs.ultralytics.com/datasets/detect/coco8-grayscale/
5
+ # Example usage: yolo train data=coco8-grayscale.yaml
6
+ # parent
7
+ # ├── ultralytics
8
+ # └── datasets
9
+ # └── coco8-grayscale ← downloads here (1 MB)
10
+
11
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
12
+ path: ../datasets/coco8-grayscale # dataset root dir
13
+ train: images/train # train images (relative to 'path') 4 images
14
+ val: images/val # val images (relative to 'path') 4 images
15
+ test: # test images (optional)
16
+
17
+ channels: 1
18
+
19
+ # Classes
20
+ names:
21
+ 0: person
22
+ 1: bicycle
23
+ 2: car
24
+ 3: motorcycle
25
+ 4: airplane
26
+ 5: bus
27
+ 6: train
28
+ 7: truck
29
+ 8: boat
30
+ 9: traffic light
31
+ 10: fire hydrant
32
+ 11: stop sign
33
+ 12: parking meter
34
+ 13: bench
35
+ 14: bird
36
+ 15: cat
37
+ 16: dog
38
+ 17: horse
39
+ 18: sheep
40
+ 19: cow
41
+ 20: elephant
42
+ 21: bear
43
+ 22: zebra
44
+ 23: giraffe
45
+ 24: backpack
46
+ 25: umbrella
47
+ 26: handbag
48
+ 27: tie
49
+ 28: suitcase
50
+ 29: frisbee
51
+ 30: skis
52
+ 31: snowboard
53
+ 32: sports ball
54
+ 33: kite
55
+ 34: baseball bat
56
+ 35: baseball glove
57
+ 36: skateboard
58
+ 37: surfboard
59
+ 38: tennis racket
60
+ 39: bottle
61
+ 40: wine glass
62
+ 41: cup
63
+ 42: fork
64
+ 43: knife
65
+ 44: spoon
66
+ 45: bowl
67
+ 46: banana
68
+ 47: apple
69
+ 48: sandwich
70
+ 49: orange
71
+ 50: broccoli
72
+ 51: carrot
73
+ 52: hot dog
74
+ 53: pizza
75
+ 54: donut
76
+ 55: cake
77
+ 56: chair
78
+ 57: couch
79
+ 58: potted plant
80
+ 59: bed
81
+ 60: dining table
82
+ 61: toilet
83
+ 62: tv
84
+ 63: laptop
85
+ 64: mouse
86
+ 65: remote
87
+ 66: keyboard
88
+ 67: cell phone
89
+ 68: microwave
90
+ 69: oven
91
+ 70: toaster
92
+ 71: sink
93
+ 72: refrigerator
94
+ 73: book
95
+ 74: clock
96
+ 75: vase
97
+ 76: scissors
98
+ 77: teddy bear
99
+ 78: hair drier
100
+ 79: toothbrush
101
+
102
+ # Download script/URL (optional)
103
+ download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8-grayscale.zip
@@ -3,7 +3,7 @@
3
3
  import math
4
4
  import random
5
5
  from copy import deepcopy
6
- from typing import List, Tuple, Union
6
+ from typing import Any, Dict, List, Tuple, Union
7
7
 
8
8
  import cv2
9
9
  import numpy as np
@@ -2416,7 +2416,7 @@ class RandomLoadText:
2416
2416
  self.padding = padding
2417
2417
  self.padding_value = padding_value
2418
2418
 
2419
- def __call__(self, labels: dict) -> dict:
2419
+ def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:
2420
2420
  """
2421
2421
  Randomly sample positive and negative texts and update class indices accordingly.
2422
2422
 
@@ -673,8 +673,8 @@ class Model(torch.nn.Module):
673
673
  custom = {"verbose": False} # method defaults
674
674
  args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
675
675
  fmts = export_formats()
676
- export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get(format, []))
677
- export_kwargs = {k: v for k, v in args.items() if k in export_args - set(["batch"])}
676
+ export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get(format, [])) - {"batch"}
677
+ export_kwargs = {k: v for k, v in args.items() if k in export_args}
678
678
  return benchmark(
679
679
  model=self,
680
680
  data=data, # if no 'data' argument passed set data=None for default datasets
@@ -754,7 +754,7 @@ class Model(torch.nn.Module):
754
754
  **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
755
755
  data (str): Path to dataset configuration file.
756
756
  epochs (int): Number of training epochs.
757
- batch_size (int): Batch size for training.
757
+ batch (int): Batch size for training.
758
758
  imgsz (int): Input image size.
759
759
  device (str): Device to run training on (e.g., 'cuda', 'cpu').
760
760
  workers (int): Number of worker threads for data loading.
@@ -1033,7 +1033,7 @@ class Model(torch.nn.Module):
1033
1033
  self.callbacks[event] = [callbacks.default_callbacks[event][0]]
1034
1034
 
1035
1035
  @staticmethod
1036
- def _reset_ckpt_args(args: dict) -> dict:
1036
+ def _reset_ckpt_args(args: Dict[str, Any]) -> Dict[str, Any]:
1037
1037
  """
1038
1038
  Reset specific arguments when loading a PyTorch model checkpoint.
1039
1039
 
@@ -329,7 +329,7 @@ class BaseValidator:
329
329
  """Update metrics based on predictions and batch."""
330
330
  pass
331
331
 
332
- def finalize_metrics(self, *args, **kwargs):
332
+ def finalize_metrics(self):
333
333
  """Finalize and return all metrics."""
334
334
  pass
335
335
 
@@ -1,12 +1,4 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
- """
3
- YOLO-NAS model interface.
4
-
5
- Examples:
6
- >>> from ultralytics import NAS
7
- >>> model = NAS("yolo_nas_s")
8
- >>> results = model.predict("ultralytics/assets/bus.jpg")
9
- """
10
2
 
11
3
  from pathlib import Path
12
4
  from typing import Any, Dict
@@ -79,7 +79,9 @@ class ClassificationValidator(BaseValidator):
79
79
  """Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
80
80
  self.names = model.names
81
81
  self.nc = len(model.names)
82
- self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
82
+ self.confusion_matrix = ConfusionMatrix(
83
+ nc=self.nc, conf=self.args.conf, names=self.names.values(), task="classify"
84
+ )
83
85
  self.pred = []
84
86
  self.targets = []
85
87
 
@@ -106,14 +108,10 @@ class ClassificationValidator(BaseValidator):
106
108
  self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
107
109
  self.targets.append(batch["cls"].type(torch.int32).cpu())
108
110
 
109
- def finalize_metrics(self, *args, **kwargs):
111
+ def finalize_metrics(self) -> None:
110
112
  """
111
113
  Finalize metrics including confusion matrix and processing speed.
112
114
 
113
- Args:
114
- *args (Any): Variable length argument list.
115
- **kwargs (Any): Arbitrary keyword arguments.
116
-
117
115
  Notes:
118
116
  This method processes the accumulated predictions and targets to generate the confusion matrix,
119
117
  optionally plots it, and updates the metrics object with speed information.
@@ -128,9 +126,7 @@ class ClassificationValidator(BaseValidator):
128
126
  self.confusion_matrix.process_cls_preds(self.pred, self.targets)
129
127
  if self.args.plots:
130
128
  for normalize in True, False:
131
- self.confusion_matrix.plot(
132
- save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
133
- )
129
+ self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
134
130
  self.metrics.speed = self.speed
135
131
  self.metrics.confusion_matrix = self.confusion_matrix
136
132
  self.metrics.save_dir = self.save_dir
@@ -42,7 +42,7 @@ class DetectionValidator(BaseValidator):
42
42
  >>> validator()
43
43
  """
44
44
 
45
- def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
45
+ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
46
46
  """
47
47
  Initialize detection validator with necessary variables and settings.
48
48
 
@@ -102,7 +102,7 @@ class DetectionValidator(BaseValidator):
102
102
  self.end2end = getattr(model, "end2end", False)
103
103
  self.metrics.names = self.names
104
104
  self.metrics.plot = self.args.plots
105
- self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
105
+ self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, names=self.names.values())
106
106
  self.seen = 0
107
107
  self.jdict = []
108
108
  self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
@@ -227,14 +227,11 @@ class DetectionValidator(BaseValidator):
227
227
  self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
228
228
  )
229
229
 
230
- def finalize_metrics(self, *args: Any, **kwargs: Any) -> None:
231
- """
232
- Set final values for metrics speed and confusion matrix.
233
-
234
- Args:
235
- *args (Any): Variable length argument list.
236
- **kwargs (Any): Arbitrary keyword arguments.
237
- """
230
+ def finalize_metrics(self) -> None:
231
+ """Set final values for metrics speed and confusion matrix."""
232
+ if self.args.plots:
233
+ for normalize in True, False:
234
+ self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
238
235
  self.metrics.speed = self.speed
239
236
  self.metrics.confusion_matrix = self.confusion_matrix
240
237
 
@@ -267,12 +264,6 @@ class DetectionValidator(BaseValidator):
267
264
  pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
268
265
  )
269
266
 
270
- if self.args.plots:
271
- for normalize in True, False:
272
- self.confusion_matrix.plot(
273
- save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
274
- )
275
-
276
267
  def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
277
268
  """
278
269
  Return correct prediction matrix.
@@ -290,7 +281,7 @@ class DetectionValidator(BaseValidator):
290
281
  iou = box_iou(gt_bboxes, detections[:, :4])
291
282
  return self.match_predictions(detections[:, 5], gt_cls, iou)
292
283
 
293
- def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None):
284
+ def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None) -> torch.utils.data.Dataset:
294
285
  """
295
286
  Build YOLO Dataset.
296
287
 
@@ -1,7 +1,7 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from pathlib import Path
4
- from typing import Dict, List, Tuple, Union
4
+ from typing import Any, Dict, List, Tuple, Union
5
5
 
6
6
  import torch
7
7
 
@@ -40,7 +40,7 @@ class OBBValidator(DetectionValidator):
40
40
  >>> validator(model=args["model"])
41
41
  """
42
42
 
43
- def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
43
+ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
44
44
  """
45
45
  Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
46
46
 
@@ -58,8 +58,13 @@ class OBBValidator(DetectionValidator):
58
58
  self.args.task = "obb"
59
59
  self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
60
60
 
61
- def init_metrics(self, model):
62
- """Initialize evaluation metrics for YOLO."""
61
+ def init_metrics(self, model: torch.nn.Module) -> None:
62
+ """
63
+ Initialize evaluation metrics for YOLO obb validation.
64
+
65
+ Args:
66
+ model (torch.nn.Module): Model to validate.
67
+ """
63
68
  super().init_metrics(model)
64
69
  val = self.data.get(self.args.split, "") # validation path
65
70
  self.is_dota = isinstance(val, str) and "DOTA" in val # check if dataset is DOTA format
@@ -94,7 +99,7 @@ class OBBValidator(DetectionValidator):
94
99
 
95
100
  Args:
96
101
  si (int): Batch index to process.
97
- batch (dict): Dictionary containing batch data with keys:
102
+ batch (Dict[str, Any]): Dictionary containing batch data with keys:
98
103
  - batch_idx: Tensor of batch indices
99
104
  - cls: Tensor of class labels
100
105
  - bboxes: Tensor of bounding boxes
@@ -103,7 +108,7 @@ class OBBValidator(DetectionValidator):
103
108
  - ratio_pad: Ratio and padding information
104
109
 
105
110
  Returns:
106
- (dict): Prepared batch data with scaled bounding boxes and metadata.
111
+ (Dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
107
112
  """
108
113
  idx = batch["batch_idx"] == si
109
114
  cls = batch["cls"][idx].squeeze(-1)
@@ -116,7 +121,7 @@ class OBBValidator(DetectionValidator):
116
121
  ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels
117
122
  return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
118
123
 
119
- def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict) -> torch.Tensor:
124
+ def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> torch.Tensor:
120
125
  """
121
126
  Prepare predictions by scaling bounding boxes to original image dimensions.
122
127
 
@@ -125,7 +130,7 @@ class OBBValidator(DetectionValidator):
125
130
 
126
131
  Args:
127
132
  pred (torch.Tensor): Prediction tensor containing bounding box coordinates and other information.
128
- pbatch (dict): Dictionary containing batch information with keys:
133
+ pbatch (Dict[str, Any]): Dictionary containing batch information with keys:
129
134
  - imgsz (tuple): Model input image size.
130
135
  - ori_shape (tuple): Original image shape.
131
136
  - ratio_pad (tuple): Ratio and padding information for scaling.
@@ -139,13 +144,13 @@ class OBBValidator(DetectionValidator):
139
144
  ) # native-space pred
140
145
  return predn
141
146
 
142
- def plot_predictions(self, batch: Dict, preds: List[torch.Tensor], ni: int):
147
+ def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
143
148
  """
144
149
  Plot predicted bounding boxes on input images and save the result.
145
150
 
146
151
  Args:
147
- batch (dict): Batch data containing images, file paths, and other metadata.
148
- preds (list): List of prediction tensors for each image in the batch.
152
+ batch (Dict[str, Any]): Batch data containing images, file paths, and other metadata.
153
+ preds (List[torch.Tensor]): List of prediction tensors for each image in the batch.
149
154
  ni (int): Batch index used for naming the output file.
150
155
 
151
156
  Examples:
@@ -163,7 +168,7 @@ class OBBValidator(DetectionValidator):
163
168
  on_plot=self.on_plot,
164
169
  ) # pred
165
170
 
166
- def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]):
171
+ def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]) -> None:
167
172
  """
168
173
  Convert YOLO predictions to COCO JSON format with rotated bounding box information.
169
174
 
@@ -192,7 +197,9 @@ class OBBValidator(DetectionValidator):
192
197
  }
193
198
  )
194
199
 
195
- def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Union[Path, str]):
200
+ def save_one_txt(
201
+ self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Union[Path, str]
202
+ ) -> None:
196
203
  """
197
204
  Save YOLO OBB detections to a text file in normalized coordinates.
198
205
 
@@ -200,7 +207,7 @@ class OBBValidator(DetectionValidator):
200
207
  predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
201
208
  class predictions, and angles in format (x, y, w, h, conf, cls, angle).
202
209
  save_conf (bool): Whether to save confidence scores in the text file.
203
- shape (tuple): Original image shape in format (height, width).
210
+ shape (Tuple[int, int]): Original image shape in format (height, width).
204
211
  file (Path | str): Output file path to save detections.
205
212
 
206
213
  Examples:
@@ -222,15 +229,15 @@ class OBBValidator(DetectionValidator):
222
229
  obb=obb,
223
230
  ).save_txt(file, save_conf=save_conf)
224
231
 
225
- def eval_json(self, stats: Dict) -> Dict:
232
+ def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
226
233
  """
227
234
  Evaluate YOLO output in JSON format and save predictions in DOTA format.
228
235
 
229
236
  Args:
230
- stats (dict): Performance statistics dictionary.
237
+ stats (Dict[str, Any]): Performance statistics dictionary.
231
238
 
232
239
  Returns:
233
- (dict): Updated performance statistics.
240
+ (Dict[str, Any]): Updated performance statistics.
234
241
  """
235
242
  if self.args.save_json and self.is_dota and len(self.jdict):
236
243
  import json