ultralytics 8.3.145__py3-none-any.whl → 8.3.147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +3 -0
- tests/test_cli.py +2 -7
- tests/test_python.py +55 -18
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +0 -1
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/data/augment.py +2 -2
- ultralytics/engine/model.py +4 -4
- ultralytics/engine/validator.py +1 -1
- ultralytics/models/nas/model.py +0 -8
- ultralytics/models/yolo/classify/val.py +5 -9
- ultralytics/models/yolo/detect/val.py +8 -17
- ultralytics/models/yolo/obb/val.py +24 -17
- ultralytics/models/yolo/pose/val.py +19 -14
- ultralytics/models/yolo/segment/val.py +52 -44
- ultralytics/nn/tasks.py +3 -0
- ultralytics/solutions/analytics.py +17 -9
- ultralytics/solutions/object_counter.py +2 -4
- ultralytics/trackers/bot_sort.py +4 -2
- ultralytics/utils/__init__.py +2 -3
- ultralytics/utils/benchmarks.py +15 -15
- ultralytics/utils/checks.py +10 -5
- ultralytics/utils/downloads.py +1 -0
- ultralytics/utils/metrics.py +52 -33
- ultralytics/utils/plotting.py +10 -7
- ultralytics/utils/torch_utils.py +2 -2
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/METADATA +1 -1
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/RECORD +32 -31
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/WHEEL +1 -1
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.147.dist-info}/top_level.txt +0 -0
tests/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
 from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks
 
 # Constants used in tests
@@ -10,6 +11,8 @@ SOURCES_LIST = [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]
 TMP = (ROOT / "../tests/tmp").resolve()  # temp directory for test files
 CUDA_IS_AVAILABLE = checks.cuda_is_available()
 CUDA_DEVICE_COUNT = checks.cuda_device_count()
+TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS]
+MODELS = frozenset(list(TASK2MODEL.values()) + ["yolo11n-grayscale.pt"])
 
 __all__ = (
     "MODEL",
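These two constants replace per-file copies so that test_cli.py and test_python.py parametrize from a single source. A minimal sketch of how pytest consumes them (the triples here are hypothetical stand-ins, not the real TASK2MODEL/TASK2DATA values):

import pytest

# Hypothetical stand-ins for the shared constants in tests/__init__.py
TASK_MODEL_DATA = [
    ("detect", "yolo11n.pt", "coco8.yaml"),
    ("segment", "yolo11n-seg.pt", "coco8-seg.yaml"),
]


@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
def test_val(task: str, model: str, data: str) -> None:
    """Each (task, model, data) triple is collected as its own test case."""
    assert model.endswith(".pt") and data.endswith(".yaml")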
tests/test_cli.py
CHANGED
@@ -5,15 +5,10 @@ import subprocess
 import pytest
 from PIL import Image
 
-from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE
-from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
+from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODELS, TASK_MODEL_DATA
 from ultralytics.utils import ARM64, ASSETS, LINUX, WEIGHTS_DIR, checks
 from ultralytics.utils.torch_utils import TORCH_1_9
 
-# Constants
-TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS]
-MODELS = [WEIGHTS_DIR / TASK2MODEL[task] for task in TASKS]
-
 
 def run(cmd: str) -> None:
     """Execute a shell command using subprocess."""
@@ -44,7 +39,7 @@ def test_val(task: str, model: str, data: str) -> None:
 @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
 def test_predict(task: str, model: str, data: str) -> None:
     """Test YOLO prediction on provided sample assets for specified task and model."""
-    run(f"yolo predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
+    run(f"yolo {task} predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
 
 
 @pytest.mark.parametrize("model", MODELS)
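test_predict now names the task on the command line instead of letting the CLI infer it from the weights. For the detect case, the command under test expands to roughly the following (concrete paths are assumed; the suite builds them from WEIGHTS_DIR and ASSETS):

import subprocess

# Assumed concrete values for one parametrized case of test_predict
task, model = "detect", "yolo11n.pt"
cmd = f"yolo {task} predict model={model} source=ultralytics/assets imgsz=32 save save_crop save_txt"
subprocess.run(cmd.split(), check=True)  # approximates what the suite's run() helper executes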
tests/test_python.py
CHANGED
@@ -12,9 +12,9 @@ import pytest
 import torch
 from PIL import Image
 
-from tests import CFG, MODEL, SOURCE, SOURCES_LIST, TMP
+from tests import CFG, MODEL, MODELS, SOURCE, SOURCES_LIST, TASK_MODEL_DATA, TMP
 from ultralytics import RTDETR, YOLO
-from ultralytics.cfg import
+from ultralytics.cfg import TASK2DATA, TASKS
 from ultralytics.data.build import load_inference_source
 from ultralytics.utils import (
     ARM64,
@@ -112,21 +112,22 @@ def test_predict_csv_single_row():
 @pytest.mark.parametrize("model_name", MODELS)
 def test_predict_img(model_name):
     """Test YOLO model predictions on various image input types and sources, including online images."""
+    channels = 1 if model_name == "yolo11n-grayscale.pt" else 3
     model = YOLO(WEIGHTS_DIR / model_name)
-    im = cv2.imread(str(SOURCE))  # uint8 numpy array
+    im = cv2.imread(str(SOURCE), flags=cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR)  # uint8 numpy array
     assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1  # PIL
     assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1  # ndarray
-    assert len(model(torch.rand((2,
+    assert len(model(torch.rand((2, channels, 32, 32)), imgsz=32)) == 2  # batch-size 2 Tensor, FP32 0.0-1.0 RGB order
     assert len(model(source=[im, im], save=True, save_txt=True, imgsz=32)) == 2  # batch
     assert len(list(model(source=[im, im], save=True, stream=True, imgsz=32))) == 2  # stream
-    assert len(model(torch.zeros(320, 640,
+    assert len(model(torch.zeros(320, 640, channels).numpy().astype(np.uint8), imgsz=32)) == 1  # tensor to numpy
     batch = [
         str(SOURCE),  # filename
         Path(SOURCE),  # Path
         "https://github.com/ultralytics/assets/releases/download/v0.0.0/zidane.jpg" if ONLINE else SOURCE,  # URI
-
+        im,  # OpenCV
         Image.open(SOURCE),  # PIL
-        np.zeros((320, 640,
+        np.zeros((320, 640, channels), dtype=np.uint8),  # numpy
     ]
     assert len(model(batch, imgsz=32, classes=0)) == len(batch)  # multiple sources in a batch
 
@@ -177,14 +178,17 @@ def test_youtube():
 
 @pytest.mark.skipif(not ONLINE, reason="environment is offline")
 @pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
-
+@pytest.mark.parametrize("model", MODELS)
+def test_track_stream(model):
     """
     Test streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.
 
     Note imgsz=160 required for tracking for higher confidence and better matches.
     """
+    if model == "yolo11n-cls.pt":  # classification model not supported for tracking
+        return
     video_url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/decelera_portrait_min.mov"
-    model = YOLO(
+    model = YOLO(model)
     model.track(video_url, imgsz=160, tracker="bytetrack.yaml")
     model.track(video_url, imgsz=160, tracker="botsort.yaml", save_frames=True)  # test frame saving also
 
@@ -196,15 +200,23 @@ def test_track_stream():
     model.track(video_url, imgsz=160, tracker=custom_yaml)
 
 
-
+@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
+def test_val(task: str, model: str, data: str) -> None:
     """Test the validation mode of the YOLO model."""
-
-
-
-
-
-
-
+    for plots in [True, False]:  # Test both cases i.e. plots=True and plots=False
+        metrics = YOLO(model).val(data=data, imgsz=32, plots=plots)
+        metrics.to_df()
+        metrics.to_csv()
+        metrics.to_xml()
+        metrics.to_html()
+        metrics.to_json()
+        metrics.to_sql()
+        metrics.confusion_matrix.to_df()  # Tests for confusion matrix export
+        metrics.confusion_matrix.to_csv()
+        metrics.confusion_matrix.to_xml()
+        metrics.confusion_matrix.to_html()
+        metrics.confusion_matrix.to_json()
+        metrics.confusion_matrix.to_sql()
 
 
 def test_train_scratch():
@@ -268,7 +280,7 @@ def test_predict_callback_and_setup():
 
 
 @pytest.mark.parametrize("model", MODELS)
-def test_results(model):
+def test_results(model: str):
     """Test YOLO model results processing and output in various formats."""
     temp_s = "https://ultralytics.com/images/boats.jpg" if model == "yolo11n-obb.pt" else SOURCE
     results = YOLO(WEIGHTS_DIR / model)([temp_s, temp_s], imgsz=160)
@@ -699,3 +711,28 @@ def test_multichannel():
     im = np.zeros((32, 32, 10), dtype=np.uint8)
     model.predict(source=im, imgsz=32, save_txt=True, save_crop=True, augment=True)
     model.export(format="onnx")
+
+
+@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
+def test_grayscale(task: str, model: str, data: str) -> None:
+    """Test YOLO model grayscale training, validation, and prediction functionality."""
+    if task == "classify":  # not support grayscale classification yet
+        return
+    grayscale_data = Path(TMP) / f"{Path(data).stem}-grayscale.yaml"
+    data = YAML.load(checks.check_file(data))
+    data["channels"] = 1  # add additional channels key for grayscale
+    YAML.save(grayscale_data, data)
+    # remove npy files in train/val splits if exists, might be created by previous tests
+    for split in {"train", "val"}:
+        for npy_file in (Path(data["path"]) / data[split]).glob("*.npy"):
+            npy_file.unlink()
+
+    model = YOLO(model)
+    model.train(data=grayscale_data, epochs=1, imgsz=32, close_mosaic=1)
+    model.val(data=grayscale_data)
+    im = np.zeros((32, 32, 1), dtype=np.uint8)
+    model.predict(source=im, imgsz=32, save_txt=True, save_crop=True, augment=True)
+    export_model = model.export(format="onnx")
+
+    model = YOLO(export_model, task=task)
+    model.predict(source=im, imgsz=32)
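The rewritten test_val sweeps the full tabular export surface on both the metrics object and its confusion matrix. Outside the suite the same calls look like this, a hedged sketch assuming yolo11n.pt and coco8.yaml, both of which Ultralytics fetches automatically:

from ultralytics import YOLO

metrics = YOLO("yolo11n.pt").val(data="coco8.yaml", imgsz=32, plots=False)
df = metrics.to_df()  # per-class validation results as a table
cm_json = metrics.confusion_matrix.to_json()  # the confusion matrix shares the same to_* family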
ultralytics/__init__.py
CHANGED
ultralytics/cfg/__init__.py
CHANGED
ultralytics/cfg/datasets/coco8-grayscale.yaml
ADDED
@@ -0,0 +1,103 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# COCO8-Grayscale dataset (first 8 images from COCO train2017) by Ultralytics
+# Documentation: https://docs.ultralytics.com/datasets/detect/coco8-grayscale/
+# Example usage: yolo train data=coco8-grayscale.yaml
+# parent
+# ├── ultralytics
+# └── datasets
+#     └── coco8-grayscale ← downloads here (1 MB)
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/coco8-grayscale # dataset root dir
+train: images/train # train images (relative to 'path') 4 images
+val: images/val # val images (relative to 'path') 4 images
+test: # test images (optional)
+
+channels: 1
+
+# Classes
+names:
+  0: person
+  1: bicycle
+  2: car
+  3: motorcycle
+  4: airplane
+  5: bus
+  6: train
+  7: truck
+  8: boat
+  9: traffic light
+  10: fire hydrant
+  11: stop sign
+  12: parking meter
+  13: bench
+  14: bird
+  15: cat
+  16: dog
+  17: horse
+  18: sheep
+  19: cow
+  20: elephant
+  21: bear
+  22: zebra
+  23: giraffe
+  24: backpack
+  25: umbrella
+  26: handbag
+  27: tie
+  28: suitcase
+  29: frisbee
+  30: skis
+  31: snowboard
+  32: sports ball
+  33: kite
+  34: baseball bat
+  35: baseball glove
+  36: skateboard
+  37: surfboard
+  38: tennis racket
+  39: bottle
+  40: wine glass
+  41: cup
+  42: fork
+  43: knife
+  44: spoon
+  45: bowl
+  46: banana
+  47: apple
+  48: sandwich
+  49: orange
+  50: broccoli
+  51: carrot
+  52: hot dog
+  53: pizza
+  54: donut
+  55: cake
+  56: chair
+  57: couch
+  58: potted plant
+  59: bed
+  60: dining table
+  61: toilet
+  62: tv
+  63: laptop
+  64: mouse
+  65: remote
+  66: keyboard
+  67: cell phone
+  68: microwave
+  69: oven
+  70: toaster
+  71: sink
+  72: refrigerator
+  73: book
+  74: clock
+  75: vase
+  76: scissors
+  77: teddy bear
+  78: hair drier
+  79: toothbrush
+
+# Download script/URL (optional)
+download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8-grayscale.zip
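The only functional difference from coco8.yaml is the channels: 1 key, which switches the data pipeline to single-channel images. A usage sketch based on the new test_grayscale above (weights and dataset auto-download on first use):

import numpy as np
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
model.train(data="coco8-grayscale.yaml", epochs=1, imgsz=32)  # channels: 1 selects the grayscale pipeline
im = np.zeros((32, 32, 1), dtype=np.uint8)  # (H, W, 1) grayscale frame
model.predict(source=im, imgsz=32)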
ultralytics/data/augment.py
CHANGED
@@ -3,7 +3,7 @@
 import math
 import random
 from copy import deepcopy
-from typing import List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -2416,7 +2416,7 @@ class RandomLoadText:
         self.padding = padding
         self.padding_value = padding_value
 
-    def __call__(self, labels:
+    def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:
         """
         Randomly sample positive and negative texts and update class indices accordingly.
 
ultralytics/engine/model.py
CHANGED
@@ -673,8 +673,8 @@ class Model(torch.nn.Module):
         custom = {"verbose": False}  # method defaults
         args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
         fmts = export_formats()
-        export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get(format, []))
-        export_kwargs = {k: v for k, v in args.items() if k in export_args
+        export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get(format, [])) - {"batch"}
+        export_kwargs = {k: v for k, v in args.items() if k in export_args}
         return benchmark(
             model=self,
             data=data,  # if no 'data' argument passed set data=None for default datasets
@@ -754,7 +754,7 @@ class Model(torch.nn.Module):
             **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
                 data (str): Path to dataset configuration file.
                 epochs (int): Number of training epochs.
-
+                batch (int): Batch size for training.
                 imgsz (int): Input image size.
                 device (str): Device to run training on (e.g., 'cuda', 'cpu').
                 workers (int): Number of worker threads for data loading.
@@ -1033,7 +1033,7 @@ class Model(torch.nn.Module):
             self.callbacks[event] = [callbacks.default_callbacks[event][0]]
 
     @staticmethod
-    def _reset_ckpt_args(args:
+    def _reset_ckpt_args(args: Dict[str, Any]) -> Dict[str, Any]:
         """
         Reset specific arguments when loading a PyTorch model checkpoint.
 
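The benchmark fix whitelists export-specific kwargs from the merged config and now also subtracts "batch", presumably so benchmark() keeps control of batch size rather than inheriting a training value. A self-contained sketch of the set arithmetic, with a stand-in table instead of the real export_formats():

# Stand-in for the export_formats() table: format name -> accepted export arguments
fmts = {"Argument": ["onnx"], "Arguments": [["imgsz", "half", "batch"]]}
args = {"imgsz": 640, "half": False, "batch": 16, "epochs": 100}  # merged config

export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get("onnx", [])) - {"batch"}
export_kwargs = {k: v for k, v in args.items() if k in export_args}
assert export_kwargs == {"imgsz": 640, "half": False}  # "batch" and non-export keys are dropped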
ultralytics/engine/validator.py
CHANGED
ultralytics/models/nas/model.py
CHANGED
@@ -1,12 +1,4 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""
-YOLO-NAS model interface.
-
-Examples:
-    >>> from ultralytics import NAS
-    >>> model = NAS("yolo_nas_s")
-    >>> results = model.predict("ultralytics/assets/bus.jpg")
-"""
 
 from pathlib import Path
 from typing import Any, Dict
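Only the module docstring is removed here; the usage it documented is unchanged and, copied from the deleted lines above, still reads:

from ultralytics import NAS

model = NAS("yolo_nas_s")
results = model.predict("ultralytics/assets/bus.jpg")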
ultralytics/models/yolo/classify/val.py
CHANGED
@@ -79,7 +79,9 @@ class ClassificationValidator(BaseValidator):
         """Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
         self.names = model.names
         self.nc = len(model.names)
-        self.confusion_matrix = ConfusionMatrix(
+        self.confusion_matrix = ConfusionMatrix(
+            nc=self.nc, conf=self.args.conf, names=self.names.values(), task="classify"
+        )
         self.pred = []
         self.targets = []
 
@@ -106,14 +108,10 @@ class ClassificationValidator(BaseValidator):
         self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
         self.targets.append(batch["cls"].type(torch.int32).cpu())
 
-    def finalize_metrics(self
+    def finalize_metrics(self) -> None:
         """
         Finalize metrics including confusion matrix and processing speed.
 
-        Args:
-            *args (Any): Variable length argument list.
-            **kwargs (Any): Arbitrary keyword arguments.
-
         Notes:
             This method processes the accumulated predictions and targets to generate the confusion matrix,
             optionally plots it, and updates the metrics object with speed information.
@@ -128,9 +126,7 @@ class ClassificationValidator(BaseValidator):
         self.confusion_matrix.process_cls_preds(self.pred, self.targets)
         if self.args.plots:
             for normalize in True, False:
-                self.confusion_matrix.plot(
-                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
-                )
+                self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
         self.metrics.save_dir = self.save_dir
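Across the validators in this release, ConfusionMatrix now receives the class names (and, for classification, the task) at construction time, so the later plot() calls drop their names argument. A hedged sketch of the new construction pattern with a hypothetical class map:

from ultralytics.utils.metrics import ConfusionMatrix

names = {0: "cat", 1: "dog"}  # hypothetical class map
cm = ConfusionMatrix(nc=len(names), conf=0.25, names=names.values(), task="classify")
# later, e.g. in finalize_metrics: cm.plot(save_dir=save_dir, normalize=True)  # no names kwarg needed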
ultralytics/models/yolo/detect/val.py
CHANGED
@@ -42,7 +42,7 @@ class DetectionValidator(BaseValidator):
         >>> validator()
         """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
         """
         Initialize detection validator with necessary variables and settings.
 
@@ -102,7 +102,7 @@
         self.end2end = getattr(model, "end2end", False)
         self.metrics.names = self.names
         self.metrics.plot = self.args.plots
-        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
+        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, names=self.names.values())
         self.seen = 0
         self.jdict = []
         self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
@@ -227,14 +227,11 @@
             self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
         )
 
-    def finalize_metrics(self
-        """
-
-
-
-            *args (Any): Variable length argument list.
-            **kwargs (Any): Arbitrary keyword arguments.
-        """
+    def finalize_metrics(self) -> None:
+        """Set final values for metrics speed and confusion matrix."""
+        if self.args.plots:
+            for normalize in True, False:
+                self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
 
@@ -267,12 +264,6 @@
                     pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
                 )
 
-        if self.args.plots:
-            for normalize in True, False:
-                self.confusion_matrix.plot(
-                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
-                )
-
     def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
         """
         Return correct prediction matrix.
@@ -290,7 +281,7 @@
         iou = box_iou(gt_bboxes, detections[:, :4])
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
-    def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None):
+    def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None) -> torch.utils.data.Dataset:
         """
         Build YOLO Dataset.
 
ultralytics/models/yolo/obb/val.py
CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 
@@ -40,7 +40,7 @@ class OBBValidator(DetectionValidator):
         >>> validator(model=args["model"])
         """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
         """
         Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
 
@@ -58,8 +58,13 @@ class OBBValidator(DetectionValidator):
         self.args.task = "obb"
         self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
 
-    def init_metrics(self, model):
-        """
+    def init_metrics(self, model: torch.nn.Module) -> None:
+        """
+        Initialize evaluation metrics for YOLO obb validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         super().init_metrics(model)
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
@@ -94,7 +99,7 @@
 
         Args:
             si (int): Batch index to process.
-            batch (
+            batch (Dict[str, Any]): Dictionary containing batch data with keys:
                 - batch_idx: Tensor of batch indices
                 - cls: Tensor of class labels
                 - bboxes: Tensor of bounding boxes
@@ -103,7 +108,7 @@
                 - ratio_pad: Ratio and padding information
 
         Returns:
-            (
+            (Dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
         """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
@@ -116,7 +121,7 @@
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
-    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict) -> torch.Tensor:
+    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> torch.Tensor:
         """
         Prepare predictions by scaling bounding boxes to original image dimensions.
 
@@ -125,7 +130,7 @@
 
         Args:
             pred (torch.Tensor): Prediction tensor containing bounding box coordinates and other information.
-            pbatch (
+            pbatch (Dict[str, Any]): Dictionary containing batch information with keys:
                 - imgsz (tuple): Model input image size.
                 - ori_shape (tuple): Original image shape.
                 - ratio_pad (tuple): Ratio and padding information for scaling.
@@ -139,13 +144,13 @@
         )  # native-space pred
         return predn
 
-    def plot_predictions(self, batch: Dict, preds: List[torch.Tensor], ni: int):
+    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
         """
         Plot predicted bounding boxes on input images and save the result.
 
         Args:
-            batch (
-            preds (
+            batch (Dict[str, Any]): Batch data containing images, file paths, and other metadata.
+            preds (List[torch.Tensor]): List of prediction tensors for each image in the batch.
             ni (int): Batch index used for naming the output file.
 
         Examples:
@@ -163,7 +168,7 @@
             on_plot=self.on_plot,
         )  # pred
 
-    def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]):
+    def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]) -> None:
         """
         Convert YOLO predictions to COCO JSON format with rotated bounding box information.
 
@@ -192,7 +197,9 @@
             }
         )
 
-    def save_one_txt(
+    def save_one_txt(
+        self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Union[Path, str]
+    ) -> None:
         """
         Save YOLO OBB detections to a text file in normalized coordinates.
 
@@ -200,7 +207,7 @@
             predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
                 class predictions, and angles in format (x, y, w, h, conf, cls, angle).
             save_conf (bool): Whether to save confidence scores in the text file.
-            shape (
+            shape (Tuple[int, int]): Original image shape in format (height, width).
             file (Path | str): Output file path to save detections.
 
         Examples:
@@ -222,15 +229,15 @@
             obb=obb,
         ).save_txt(file, save_conf=save_conf)
 
-    def eval_json(self, stats: Dict) -> Dict:
+    def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
         """
         Evaluate YOLO output in JSON format and save predictions in DOTA format.
 
         Args:
-            stats (
+            stats (Dict[str, Any]): Performance statistics dictionary.
 
         Returns:
-            (
+            (Dict[str, Any]): Updated performance statistics.
         """
         if self.args.save_json and self.is_dota and len(self.jdict):
            import json