ultralytics 8.3.145__py3-none-any.whl → 8.3.146__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +3 -0
- tests/test_cli.py +2 -7
- tests/test_python.py +42 -12
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +0 -1
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/data/augment.py +2 -2
- ultralytics/engine/model.py +3 -3
- ultralytics/engine/validator.py +1 -1
- ultralytics/models/nas/model.py +0 -8
- ultralytics/models/yolo/classify/val.py +1 -5
- ultralytics/models/yolo/detect/val.py +9 -16
- ultralytics/models/yolo/obb/val.py +24 -17
- ultralytics/models/yolo/pose/val.py +19 -14
- ultralytics/models/yolo/segment/val.py +52 -44
- ultralytics/solutions/analytics.py +17 -9
- ultralytics/solutions/object_counter.py +2 -4
- ultralytics/trackers/bot_sort.py +4 -2
- ultralytics/utils/__init__.py +1 -2
- ultralytics/utils/benchmarks.py +15 -15
- ultralytics/utils/checks.py +10 -5
- ultralytics/utils/downloads.py +1 -0
- ultralytics/utils/metrics.py +25 -26
- ultralytics/utils/plotting.py +10 -7
- ultralytics/utils/torch_utils.py +2 -2
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.146.dist-info}/METADATA +1 -1
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.146.dist-info}/RECORD +31 -30
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.146.dist-info}/WHEEL +1 -1
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.146.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.146.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.145.dist-info → ultralytics-8.3.146.dist-info}/top_level.txt +0 -0
tests/__init__.py
CHANGED
```diff
@@ -1,5 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
 from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks
 
 # Constants used in tests
@@ -10,6 +11,8 @@ SOURCES_LIST = [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]
 TMP = (ROOT / "../tests/tmp").resolve()  # temp directory for test files
 CUDA_IS_AVAILABLE = checks.cuda_is_available()
 CUDA_DEVICE_COUNT = checks.cuda_device_count()
+TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS]
+MODELS = frozenset(list(TASK2MODEL.values()) + ["yolo11n-grayscale.pt"])
 
 __all__ = (
     "MODEL",
```
tests/test_cli.py
CHANGED
```diff
@@ -5,15 +5,10 @@ import subprocess
 import pytest
 from PIL import Image
 
-from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE
-from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
+from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODELS, TASK_MODEL_DATA
 from ultralytics.utils import ARM64, ASSETS, LINUX, WEIGHTS_DIR, checks
 from ultralytics.utils.torch_utils import TORCH_1_9
 
-# Constants
-TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS]
-MODELS = [WEIGHTS_DIR / TASK2MODEL[task] for task in TASKS]
-
 
 def run(cmd: str) -> None:
     """Execute a shell command using subprocess."""
@@ -44,7 +39,7 @@ def test_val(task: str, model: str, data: str) -> None:
 @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
 def test_predict(task: str, model: str, data: str) -> None:
     """Test YOLO prediction on provided sample assets for specified task and model."""
-    run(f"yolo predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
+    run(f"yolo {task} predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
 
 
 @pytest.mark.parametrize("model", MODELS)
```
tests/test_python.py
CHANGED
```diff
@@ -12,9 +12,9 @@ import pytest
 import torch
 from PIL import Image
 
-from tests import CFG, MODEL, SOURCE, SOURCES_LIST, TMP
+from tests import CFG, MODEL, MODELS, SOURCE, SOURCES_LIST, TASK_MODEL_DATA, TMP
 from ultralytics import RTDETR, YOLO
-from ultralytics.cfg import MODELS, TASK2DATA, TASKS
+from ultralytics.cfg import TASK2DATA, TASKS
 from ultralytics.data.build import load_inference_source
 from ultralytics.utils import (
     ARM64,
@@ -112,21 +112,22 @@ def test_predict_csv_single_row():
 @pytest.mark.parametrize("model_name", MODELS)
 def test_predict_img(model_name):
     """Test YOLO model predictions on various image input types and sources, including online images."""
+    channels = 1 if model_name == "yolo11n-grayscale.pt" else 3
     model = YOLO(WEIGHTS_DIR / model_name)
-    im = cv2.imread(str(SOURCE))  # uint8 numpy array
+    im = cv2.imread(str(SOURCE), flags=cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR)  # uint8 numpy array
     assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1  # PIL
     assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1  # ndarray
-    assert len(model(torch.rand((2, 3, 32, 32)), imgsz=32)) == 2  # batch-size 2 Tensor, FP32 0.0-1.0 RGB order
+    assert len(model(torch.rand((2, channels, 32, 32)), imgsz=32)) == 2  # batch-size 2 Tensor, FP32 0.0-1.0 RGB order
     assert len(model(source=[im, im], save=True, save_txt=True, imgsz=32)) == 2  # batch
     assert len(list(model(source=[im, im], save=True, stream=True, imgsz=32))) == 2  # stream
-    assert len(model(torch.zeros(320, 640, 3).numpy().astype(np.uint8), imgsz=32)) == 1  # tensor to numpy
+    assert len(model(torch.zeros(320, 640, channels).numpy().astype(np.uint8), imgsz=32)) == 1  # tensor to numpy
     batch = [
         str(SOURCE),  # filename
         Path(SOURCE),  # Path
         "https://github.com/ultralytics/assets/releases/download/v0.0.0/zidane.jpg" if ONLINE else SOURCE,  # URI
-        cv2.imread(str(SOURCE)),  # OpenCV
+        im,  # OpenCV
         Image.open(SOURCE),  # PIL
-        np.zeros((320, 640, 3), dtype=np.uint8),  # numpy
+        np.zeros((320, 640, channels), dtype=np.uint8),  # numpy
     ]
     assert len(model(batch, imgsz=32, classes=0)) == len(batch)  # multiple sources in a batch
 
@@ -177,14 +178,17 @@ def test_youtube():
 
 @pytest.mark.skipif(not ONLINE, reason="environment is offline")
 @pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
-def test_track_stream():
+@pytest.mark.parametrize("model", MODELS)
+def test_track_stream(model):
     """
     Test streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.
 
     Note imgsz=160 required for tracking for higher confidence and better matches.
     """
+    if model == "yolo11n-cls.pt":  # classification model not supported for tracking
+        return
     video_url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/decelera_portrait_min.mov"
-    model = YOLO(MODEL)
+    model = YOLO(model)
     model.track(video_url, imgsz=160, tracker="bytetrack.yaml")
     model.track(video_url, imgsz=160, tracker="botsort.yaml", save_frames=True)  # test frame saving also
 
@@ -196,9 +200,10 @@ def test_track_stream():
     model.track(video_url, imgsz=160, tracker=custom_yaml)
 
 
-def test_val():
+@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
+def test_val(task: str, model: str, data: str) -> None:
     """Test the validation mode of the YOLO model."""
-    metrics = YOLO(MODEL).val(data="coco8.yaml", imgsz=32)
+    metrics = YOLO(model).val(data=data, imgsz=32)
     metrics.to_df()
     metrics.to_csv()
     metrics.to_xml()
@@ -268,7 +273,7 @@ def test_predict_callback_and_setup():
 
 
 @pytest.mark.parametrize("model", MODELS)
-def test_results(model):
+def test_results(model: str):
     """Test YOLO model results processing and output in various formats."""
     temp_s = "https://ultralytics.com/images/boats.jpg" if model == "yolo11n-obb.pt" else SOURCE
     results = YOLO(WEIGHTS_DIR / model)([temp_s, temp_s], imgsz=160)
@@ -699,3 +704,28 @@ def test_multichannel():
     im = np.zeros((32, 32, 10), dtype=np.uint8)
     model.predict(source=im, imgsz=32, save_txt=True, save_crop=True, augment=True)
     model.export(format="onnx")
+
+
+@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
+def test_grayscale(task: str, model: str, data: str) -> None:
+    """Test YOLO model grayscale training, validation, and prediction functionality."""
+    if task == "classify":  # grayscale classification not supported yet
+        return
+    grayscale_data = Path(TMP) / f"{Path(data).stem}-grayscale.yaml"
+    data = YAML.load(checks.check_file(data))
+    data["channels"] = 1  # add additional channels key for grayscale
+    YAML.save(grayscale_data, data)
+    # remove any npy files in train/val splits that previous tests may have created
+    for split in {"train", "val"}:
+        for npy_file in (Path(data["path"]) / data[split]).glob("*.npy"):
+            npy_file.unlink()
+
+    model = YOLO(model)
+    model.train(data=grayscale_data, epochs=1, imgsz=32, close_mosaic=1)
+    model.val(data=grayscale_data)
+    im = np.zeros((32, 32, 1), dtype=np.uint8)
+    model.predict(source=im, imgsz=32, save_txt=True, save_crop=True, augment=True)
+    export_model = model.export(format="onnx")
+
+    model = YOLO(export_model, task=task)
+    model.predict(source=im, imgsz=32)
```
ultralytics/__init__.py
CHANGED
ultralytics/cfg/__init__.py
CHANGED
ultralytics/cfg/datasets/coco8-grayscale.yaml
ADDED
```diff
@@ -0,0 +1,103 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# COCO8-Grayscale dataset (first 8 images from COCO train2017) by Ultralytics
+# Documentation: https://docs.ultralytics.com/datasets/detect/coco8-grayscale/
+# Example usage: yolo train data=coco8-grayscale.yaml
+# parent
+# ├── ultralytics
+# └── datasets
+#     └── coco8-grayscale ← downloads here (1 MB)
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/coco8-grayscale # dataset root dir
+train: images/train # train images (relative to 'path') 4 images
+val: images/val # val images (relative to 'path') 4 images
+test: # test images (optional)
+
+channels: 1
+
+# Classes
+names:
+  0: person
+  1: bicycle
+  2: car
+  3: motorcycle
+  4: airplane
+  5: bus
+  6: train
+  7: truck
+  8: boat
+  9: traffic light
+  10: fire hydrant
+  11: stop sign
+  12: parking meter
+  13: bench
+  14: bird
+  15: cat
+  16: dog
+  17: horse
+  18: sheep
+  19: cow
+  20: elephant
+  21: bear
+  22: zebra
+  23: giraffe
+  24: backpack
+  25: umbrella
+  26: handbag
+  27: tie
+  28: suitcase
+  29: frisbee
+  30: skis
+  31: snowboard
+  32: sports ball
+  33: kite
+  34: baseball bat
+  35: baseball glove
+  36: skateboard
+  37: surfboard
+  38: tennis racket
+  39: bottle
+  40: wine glass
+  41: cup
+  42: fork
+  43: knife
+  44: spoon
+  45: bowl
+  46: banana
+  47: apple
+  48: sandwich
+  49: orange
+  50: broccoli
+  51: carrot
+  52: hot dog
+  53: pizza
+  54: donut
+  55: cake
+  56: chair
+  57: couch
+  58: potted plant
+  59: bed
+  60: dining table
+  61: toilet
+  62: tv
+  63: laptop
+  64: mouse
+  65: remote
+  66: keyboard
+  67: cell phone
+  68: microwave
+  69: oven
+  70: toaster
+  71: sink
+  72: refrigerator
+  73: book
+  74: clock
+  75: vase
+  76: scissors
+  77: teddy bear
+  78: hair drier
+  79: toothbrush
+
+# Download script/URL (optional)
+download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8-grayscale.zip
```
ultralytics/data/augment.py
CHANGED
```diff
@@ -3,7 +3,7 @@
 import math
 import random
 from copy import deepcopy
-from typing import List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -2416,7 +2416,7 @@ class RandomLoadText:
         self.padding = padding
         self.padding_value = padding_value
 
-    def __call__(self, labels: dict) -> dict:
+    def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:
         """
         Randomly sample positive and negative texts and update class indices accordingly.
 
```
ultralytics/engine/model.py
CHANGED
```diff
@@ -673,8 +673,8 @@ class Model(torch.nn.Module):
         custom = {"verbose": False}  # method defaults
         args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
         fmts = export_formats()
-        export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get(format, []))
-        export_kwargs = {k: v for k, v in args.items() if k in export_args}
+        export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get(format, [])) - {"batch"}
+        export_kwargs = {k: v for k, v in args.items() if k in export_args}
         return benchmark(
             model=self,
             data=data,  # if no 'data' argument passed set data=None for default datasets
@@ -1033,7 +1033,7 @@ class Model(torch.nn.Module):
             self.callbacks[event] = [callbacks.default_callbacks[event][0]]
 
     @staticmethod
-    def _reset_ckpt_args(args: dict) -> dict:
+    def _reset_ckpt_args(args: Dict[str, Any]) -> Dict[str, Any]:
         """
         Reset specific arguments when loading a PyTorch model checkpoint.
 
```
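The `- {"batch"}` exclusion strips `batch` from the kwargs forwarded to export, presumably so the benchmark-level batch setting cannot collide with export's own handling. A self-contained sketch of the filtering pattern, with an assumed, simplified stand-in for `export_formats()`:

```python
# Assumed stand-in for export_formats(): format keys and their supported export arguments
fmts = {"Argument": ["onnx", "engine"], "Arguments": [["imgsz", "half", "batch"], ["imgsz", "batch", "dynamic"]]}
args = {"imgsz": 640, "half": False, "batch": 16, "conf": 0.25, "mode": "benchmark"}

export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get("onnx", [])) - {"batch"}
export_kwargs = {k: v for k, v in args.items() if k in export_args}
print(export_kwargs)  # {'imgsz': 640, 'half': False} -> 'batch' and non-export keys are dropped
```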
ultralytics/engine/validator.py
CHANGED
ultralytics/models/nas/model.py
CHANGED
```diff
@@ -1,12 +1,4 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""
-YOLO-NAS model interface.
-
-Examples:
-    >>> from ultralytics import NAS
-    >>> model = NAS("yolo_nas_s")
-    >>> results = model.predict("ultralytics/assets/bus.jpg")
-"""
 
 from pathlib import Path
 from typing import Any, Dict
```
ultralytics/models/yolo/classify/val.py
CHANGED
```diff
@@ -106,14 +106,10 @@ class ClassificationValidator(BaseValidator):
         self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
         self.targets.append(batch["cls"].type(torch.int32).cpu())
 
-    def finalize_metrics(self, *args, **kwargs) -> None:
+    def finalize_metrics(self) -> None:
         """
         Finalize metrics including confusion matrix and processing speed.
 
-        Args:
-            *args (Any): Variable length argument list.
-            **kwargs (Any): Arbitrary keyword arguments.
-
         Notes:
             This method processes the accumulated predictions and targets to generate the confusion matrix,
             optionally plots it, and updates the metrics object with speed information.
```
ultralytics/models/yolo/detect/val.py
CHANGED
```diff
@@ -42,7 +42,7 @@ class DetectionValidator(BaseValidator):
     >>> validator()
     """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
         """
         Initialize detection validator with necessary variables and settings.
 
@@ -227,14 +227,13 @@ class DetectionValidator(BaseValidator):
                 self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
             )
 
-    def finalize_metrics(self, *args, **kwargs) -> None:
-        """
-        Set final values for metrics speed and confusion matrix.
-
-        Args:
-            *args (Any): Variable length argument list.
-            **kwargs (Any): Arbitrary keyword arguments.
-        """
+    def finalize_metrics(self) -> None:
+        """Set final values for metrics speed and confusion matrix."""
+        if self.args.plots:
+            for normalize in True, False:
+                self.confusion_matrix.plot(
+                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
+                )
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
 
@@ -267,12 +266,6 @@ class DetectionValidator(BaseValidator):
                 pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
             )
 
-        if self.args.plots:
-            for normalize in True, False:
-                self.confusion_matrix.plot(
-                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
-                )
-
     def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
         """
         Return correct prediction matrix.
@@ -290,7 +283,7 @@ class DetectionValidator(BaseValidator):
         iou = box_iou(gt_bboxes, detections[:, :4])
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
-    def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None):
+    def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None) -> torch.utils.data.Dataset:
         """
         Build YOLO Dataset.
 
```
ultralytics/models/yolo/obb/val.py
CHANGED
```diff
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 
@@ -40,7 +40,7 @@ class OBBValidator(DetectionValidator):
     >>> validator(model=args["model"])
     """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
         """
         Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
 
@@ -58,8 +58,13 @@ class OBBValidator(DetectionValidator):
         self.args.task = "obb"
         self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
 
-    def init_metrics(self, model):
-        """Initialize evaluation metrics for YOLO obb validation."""
+    def init_metrics(self, model: torch.nn.Module) -> None:
+        """
+        Initialize evaluation metrics for YOLO obb validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         super().init_metrics(model)
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
@@ -94,7 +99,7 @@ class OBBValidator(DetectionValidator):
 
         Args:
             si (int): Batch index to process.
-            batch (dict): Dictionary containing batch data with keys:
+            batch (Dict[str, Any]): Dictionary containing batch data with keys:
                 - batch_idx: Tensor of batch indices
                 - cls: Tensor of class labels
                 - bboxes: Tensor of bounding boxes
@@ -103,7 +108,7 @@ class OBBValidator(DetectionValidator):
                 - ratio_pad: Ratio and padding information
 
         Returns:
-            (dict): Prepared batch data with scaled bounding boxes and metadata.
+            (Dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
         """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
@@ -116,7 +121,7 @@ class OBBValidator(DetectionValidator):
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
-    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict) -> torch.Tensor:
+    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> torch.Tensor:
         """
         Prepare predictions by scaling bounding boxes to original image dimensions.
 
@@ -125,7 +130,7 @@ class OBBValidator(DetectionValidator):
 
         Args:
             pred (torch.Tensor): Prediction tensor containing bounding box coordinates and other information.
-            pbatch (dict): Dictionary containing batch information with keys:
+            pbatch (Dict[str, Any]): Dictionary containing batch information with keys:
                 - imgsz (tuple): Model input image size.
                 - ori_shape (tuple): Original image shape.
                 - ratio_pad (tuple): Ratio and padding information for scaling.
@@ -139,13 +144,13 @@ class OBBValidator(DetectionValidator):
         )  # native-space pred
         return predn
 
-    def plot_predictions(self, batch: Dict, preds: List[torch.Tensor], ni: int):
+    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
         """
         Plot predicted bounding boxes on input images and save the result.
 
         Args:
-            batch (dict): Batch data containing images, file paths, and other metadata.
-            preds (list): List of prediction tensors for each image in the batch.
+            batch (Dict[str, Any]): Batch data containing images, file paths, and other metadata.
+            preds (List[torch.Tensor]): List of prediction tensors for each image in the batch.
             ni (int): Batch index used for naming the output file.
 
         Examples:
@@ -163,7 +168,7 @@ class OBBValidator(DetectionValidator):
             on_plot=self.on_plot,
         )  # pred
 
-    def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]):
+    def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]) -> None:
         """
         Convert YOLO predictions to COCO JSON format with rotated bounding box information.
 
@@ -192,7 +197,9 @@ class OBBValidator(DetectionValidator):
             }
         )
 
-    def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Union[Path, str]):
+    def save_one_txt(
+        self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Union[Path, str]
+    ) -> None:
         """
         Save YOLO OBB detections to a text file in normalized coordinates.
 
@@ -200,7 +207,7 @@ class OBBValidator(DetectionValidator):
             predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
                 class predictions, and angles in format (x, y, w, h, conf, cls, angle).
             save_conf (bool): Whether to save confidence scores in the text file.
-            shape (tuple): Original image shape in format (height, width).
+            shape (Tuple[int, int]): Original image shape in format (height, width).
             file (Path | str): Output file path to save detections.
 
         Examples:
@@ -222,15 +229,15 @@ class OBBValidator(DetectionValidator):
             obb=obb,
         ).save_txt(file, save_conf=save_conf)
 
-    def eval_json(self, stats: Dict) -> Dict:
+    def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
         """
         Evaluate YOLO output in JSON format and save predictions in DOTA format.
 
         Args:
-            stats (dict): Performance statistics dictionary.
+            stats (Dict[str, Any]): Performance statistics dictionary.
 
         Returns:
-            (dict): Updated performance statistics.
+            (Dict[str, Any]): Updated performance statistics.
         """
         if self.args.save_json and self.is_dota and len(self.jdict):
             import json
```
|
@@ -49,7 +49,7 @@ class PoseValidator(DetectionValidator):
|
|
49
49
|
>>> validator()
|
50
50
|
"""
|
51
51
|
|
52
|
-
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
52
|
+
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
|
53
53
|
"""
|
54
54
|
Initialize a PoseValidator object for pose estimation validation.
|
55
55
|
|
@@ -107,8 +107,13 @@ class PoseValidator(DetectionValidator):
|
|
107
107
|
"mAP50-95)",
|
108
108
|
)
|
109
109
|
|
110
|
-
def init_metrics(self, model):
|
111
|
-
"""
|
110
|
+
def init_metrics(self, model: torch.nn.Module) -> None:
|
111
|
+
"""
|
112
|
+
Initialize evaluation metrics for YOLO pose validation.
|
113
|
+
|
114
|
+
Args:
|
115
|
+
model (torch.nn.Module): Model to validate.
|
116
|
+
"""
|
112
117
|
super().init_metrics(model)
|
113
118
|
self.kpt_shape = self.data["kpt_shape"]
|
114
119
|
is_pose = self.kpt_shape == [17, 3]
|
@@ -122,10 +127,10 @@ class PoseValidator(DetectionValidator):
|
|
122
127
|
|
123
128
|
Args:
|
124
129
|
si (int): Batch index.
|
125
|
-
batch (
|
130
|
+
batch (Dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
|
126
131
|
|
127
132
|
Returns:
|
128
|
-
(
|
133
|
+
(Dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
|
129
134
|
|
130
135
|
Notes:
|
131
136
|
This method extends the parent class's _prepare_batch method by adding keypoint processing.
|
@@ -151,7 +156,7 @@ class PoseValidator(DetectionValidator):
|
|
151
156
|
|
152
157
|
Args:
|
153
158
|
pred (torch.Tensor): Raw prediction tensor from the model.
|
154
|
-
pbatch (
|
159
|
+
pbatch (Dict[str, Any]): Processed batch dictionary containing image information including:
|
155
160
|
- imgsz: Image size used for inference
|
156
161
|
- ori_shape: Original image shape
|
157
162
|
- ratio_pad: Ratio and padding information for coordinate scaling
|
@@ -166,7 +171,7 @@ class PoseValidator(DetectionValidator):
|
|
166
171
|
ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
|
167
172
|
return predn, pred_kpts
|
168
173
|
|
169
|
-
def update_metrics(self, preds: List[torch.Tensor], batch: Dict[str, Any]):
|
174
|
+
def update_metrics(self, preds: List[torch.Tensor], batch: Dict[str, Any]) -> None:
|
170
175
|
"""
|
171
176
|
Update metrics with new predictions and ground truth data.
|
172
177
|
|
@@ -175,7 +180,7 @@ class PoseValidator(DetectionValidator):
|
|
175
180
|
|
176
181
|
Args:
|
177
182
|
preds (List[torch.Tensor]): List of prediction tensors from the model.
|
178
|
-
batch (
|
183
|
+
batch (Dict[str, Any]): Batch data containing images and ground truth annotations.
|
179
184
|
"""
|
180
185
|
for si, pred in enumerate(preds):
|
181
186
|
self.seen += 1
|
@@ -266,12 +271,12 @@ class PoseValidator(DetectionValidator):
|
|
266
271
|
|
267
272
|
return self.match_predictions(detections[:, 5], gt_cls, iou)
|
268
273
|
|
269
|
-
def plot_val_samples(self, batch: Dict[str, Any], ni: int):
|
274
|
+
def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
|
270
275
|
"""
|
271
276
|
Plot and save validation set samples with ground truth bounding boxes and keypoints.
|
272
277
|
|
273
278
|
Args:
|
274
|
-
batch (
|
279
|
+
batch (Dict[str, Any]): Dictionary containing batch data with keys:
|
275
280
|
- img (torch.Tensor): Batch of images
|
276
281
|
- batch_idx (torch.Tensor): Batch indices for each image
|
277
282
|
- cls (torch.Tensor): Class labels
|
@@ -292,12 +297,12 @@ class PoseValidator(DetectionValidator):
|
|
292
297
|
on_plot=self.on_plot,
|
293
298
|
)
|
294
299
|
|
295
|
-
def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int):
|
300
|
+
def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
|
296
301
|
"""
|
297
302
|
Plot and save model predictions with bounding boxes and keypoints.
|
298
303
|
|
299
304
|
Args:
|
300
|
-
batch (
|
305
|
+
batch (Dict[str, Any]): Dictionary containing batch data including images, file paths, and other metadata.
|
301
306
|
preds (List[torch.Tensor]): List of prediction tensors from the model, each containing bounding boxes,
|
302
307
|
confidence scores, class predictions, and keypoints.
|
303
308
|
ni (int): Batch index used for naming the output file.
|
@@ -323,7 +328,7 @@ class PoseValidator(DetectionValidator):
|
|
323
328
|
save_conf: bool,
|
324
329
|
shape: Tuple[int, int],
|
325
330
|
file: Path,
|
326
|
-
):
|
331
|
+
) -> None:
|
327
332
|
"""
|
328
333
|
Save YOLO pose detections to a text file in normalized coordinates.
|
329
334
|
|
@@ -349,7 +354,7 @@ class PoseValidator(DetectionValidator):
|
|
349
354
|
keypoints=pred_kpts,
|
350
355
|
).save_txt(file, save_conf=save_conf)
|
351
356
|
|
352
|
-
def pred_to_json(self, predn: torch.Tensor, filename: str):
|
357
|
+
def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
|
353
358
|
"""
|
354
359
|
Convert YOLO predictions to COCO JSON format.
|
355
360
|
|