ultralytics 8.3.144__py3-none-any.whl → 8.3.146__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +3 -0
- tests/test_cli.py +2 -7
- tests/test_python.py +42 -12
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +0 -1
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/data/augment.py +2 -2
- ultralytics/engine/model.py +14 -13
- ultralytics/engine/results.py +4 -4
- ultralytics/engine/validator.py +1 -1
- ultralytics/models/nas/model.py +0 -8
- ultralytics/models/yolo/classify/val.py +1 -5
- ultralytics/models/yolo/detect/val.py +9 -16
- ultralytics/models/yolo/obb/val.py +24 -17
- ultralytics/models/yolo/pose/val.py +19 -14
- ultralytics/models/yolo/segment/val.py +52 -44
- ultralytics/solutions/ai_gym.py +3 -5
- ultralytics/solutions/analytics.py +17 -9
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +1 -1
- ultralytics/solutions/object_counter.py +2 -8
- ultralytics/solutions/solutions.py +5 -4
- ultralytics/trackers/bot_sort.py +4 -2
- ultralytics/utils/__init__.py +1 -2
- ultralytics/utils/benchmarks.py +18 -16
- ultralytics/utils/checks.py +10 -5
- ultralytics/utils/downloads.py +1 -0
- ultralytics/utils/metrics.py +25 -26
- ultralytics/utils/plotting.py +10 -7
- ultralytics/utils/torch_utils.py +2 -2
- {ultralytics-8.3.144.dist-info → ultralytics-8.3.146.dist-info}/METADATA +2 -2
- {ultralytics-8.3.144.dist-info → ultralytics-8.3.146.dist-info}/RECORD +36 -35
- {ultralytics-8.3.144.dist-info → ultralytics-8.3.146.dist-info}/WHEEL +1 -1
- {ultralytics-8.3.144.dist-info → ultralytics-8.3.146.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.144.dist-info → ultralytics-8.3.146.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.144.dist-info → ultralytics-8.3.146.dist-info}/top_level.txt +0 -0
tests/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
 from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks
 
 # Constants used in tests
@@ -10,6 +11,8 @@ SOURCES_LIST = [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]
 TMP = (ROOT / "../tests/tmp").resolve()  # temp directory for test files
 CUDA_IS_AVAILABLE = checks.cuda_is_available()
 CUDA_DEVICE_COUNT = checks.cuda_device_count()
+TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS]
+MODELS = frozenset(list(TASK2MODEL.values()) + ["yolo11n-grayscale.pt"])
 
 __all__ = (
     "MODEL",
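The two new constants centralize the task/model/data matrix that individual test modules previously rebuilt locally. A minimal sketch of how a parametrized test consumes them (values are illustrative stand-ins; the real mappings come from TASK2MODEL/TASK2DATA in ultralytics.cfg):

```python
import pytest

# Illustrative stand-ins for the shared constants defined above
TASK_MODEL_DATA = [("detect", "yolo11n.pt", "coco8.yaml"), ("segment", "yolo11n-seg.pt", "coco8-seg.yaml")]


@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
def test_runs_once_per_task(task: str, model: str, data: str) -> None:
    # pytest generates one test case per (task, model, data) triple
    assert model.endswith(".pt") and data.endswith(".yaml")
```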
tests/test_cli.py
CHANGED
@@ -5,15 +5,10 @@ import subprocess
 import pytest
 from PIL import Image
 
-from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE
-from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
+from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODELS, TASK_MODEL_DATA
 from ultralytics.utils import ARM64, ASSETS, LINUX, WEIGHTS_DIR, checks
 from ultralytics.utils.torch_utils import TORCH_1_9
 
-# Constants
-TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS]
-MODELS = [WEIGHTS_DIR / TASK2MODEL[task] for task in TASKS]
-
 
 def run(cmd: str) -> None:
     """Execute a shell command using subprocess."""
@@ -44,7 +39,7 @@ def test_val(task: str, model: str, data: str) -> None:
 @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
 def test_predict(task: str, model: str, data: str) -> None:
     """Test YOLO prediction on provided sample assets for specified task and model."""
-    run(f"yolo predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
+    run(f"yolo {task} predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
 
 
 @pytest.mark.parametrize("model", MODELS)
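The one functional change here is that the predict test now routes through the task-specific CLI subcommand rather than the bare `yolo predict`. A sketch of the resulting invocation, mirroring the test's subprocess pattern (weights and source values are illustrative):

```python
import subprocess

# Before: "yolo predict model=..."        (task inferred from the weights)
# After:  "yolo <task> predict model=..." (task stated explicitly)
subprocess.run(
    "yolo segment predict model=yolo11n-seg.pt source=bus.jpg imgsz=32 save save_crop save_txt",
    shell=True,
    check=True,
)
```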
tests/test_python.py
CHANGED
@@ -12,9 +12,9 @@ import pytest
 import torch
 from PIL import Image
 
-from tests import CFG, MODEL, SOURCE, SOURCES_LIST, TMP
+from tests import CFG, MODEL, MODELS, SOURCE, SOURCES_LIST, TASK_MODEL_DATA, TMP
 from ultralytics import RTDETR, YOLO
-from ultralytics.cfg import
+from ultralytics.cfg import TASK2DATA, TASKS
 from ultralytics.data.build import load_inference_source
 from ultralytics.utils import (
     ARM64,
@@ -112,21 +112,22 @@ def test_predict_csv_single_row():
 @pytest.mark.parametrize("model_name", MODELS)
 def test_predict_img(model_name):
     """Test YOLO model predictions on various image input types and sources, including online images."""
+    channels = 1 if model_name == "yolo11n-grayscale.pt" else 3
     model = YOLO(WEIGHTS_DIR / model_name)
-    im = cv2.imread(str(SOURCE))  # uint8 numpy array
+    im = cv2.imread(str(SOURCE), flags=cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR)  # uint8 numpy array
     assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1  # PIL
     assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1  # ndarray
-    assert len(model(torch.rand((2, 3, 32, 32)), imgsz=32)) == 2  # batch-size 2 Tensor, FP32 0.0-1.0 RGB order
+    assert len(model(torch.rand((2, channels, 32, 32)), imgsz=32)) == 2  # batch-size 2 Tensor, FP32 0.0-1.0 RGB order
     assert len(model(source=[im, im], save=True, save_txt=True, imgsz=32)) == 2  # batch
     assert len(list(model(source=[im, im], save=True, stream=True, imgsz=32))) == 2  # stream
-    assert len(model(torch.zeros(320, 640, 3).numpy().astype(np.uint8), imgsz=32)) == 1  # tensor to numpy
+    assert len(model(torch.zeros(320, 640, channels).numpy().astype(np.uint8), imgsz=32)) == 1  # tensor to numpy
     batch = [
         str(SOURCE),  # filename
         Path(SOURCE),  # Path
         "https://github.com/ultralytics/assets/releases/download/v0.0.0/zidane.jpg" if ONLINE else SOURCE,  # URI
-
+        im,  # OpenCV
         Image.open(SOURCE),  # PIL
-        np.zeros((320, 640, 3), dtype=np.uint8),  # numpy
+        np.zeros((320, 640, channels), dtype=np.uint8),  # numpy
     ]
     assert len(model(batch, imgsz=32, classes=0)) == len(batch)  # multiple sources in a batch
 
@@ -177,14 +178,17 @@ def test_youtube():
 
 @pytest.mark.skipif(not ONLINE, reason="environment is offline")
 @pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
-def test_track_stream():
+@pytest.mark.parametrize("model", MODELS)
+def test_track_stream(model):
     """
     Test streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.
 
     Note imgsz=160 required for tracking for higher confidence and better matches.
     """
+    if model == "yolo11n-cls.pt":  # classification models are not supported for tracking
+        return
     video_url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/decelera_portrait_min.mov"
-    model = YOLO(
+    model = YOLO(model)
     model.track(video_url, imgsz=160, tracker="bytetrack.yaml")
     model.track(video_url, imgsz=160, tracker="botsort.yaml", save_frames=True)  # test frame saving also
 
@@ -196,9 +200,10 @@ def test_track_stream():
     model.track(video_url, imgsz=160, tracker=custom_yaml)
 
 
-def test_val():
+@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
+def test_val(task: str, model: str, data: str) -> None:
     """Test the validation mode of the YOLO model."""
-    metrics = YOLO(
+    metrics = YOLO(model).val(data=data, imgsz=32)
     metrics.to_df()
     metrics.to_csv()
     metrics.to_xml()
@@ -268,7 +273,7 @@ def test_predict_callback_and_setup():
 
 
 @pytest.mark.parametrize("model", MODELS)
-def test_results(model):
+def test_results(model: str):
     """Test YOLO model results processing and output in various formats."""
     temp_s = "https://ultralytics.com/images/boats.jpg" if model == "yolo11n-obb.pt" else SOURCE
     results = YOLO(WEIGHTS_DIR / model)([temp_s, temp_s], imgsz=160)
@@ -699,3 +704,28 @@ def test_multichannel():
     im = np.zeros((32, 32, 10), dtype=np.uint8)
     model.predict(source=im, imgsz=32, save_txt=True, save_crop=True, augment=True)
     model.export(format="onnx")
+
+
+@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
+def test_grayscale(task: str, model: str, data: str) -> None:
+    """Test YOLO model grayscale training, validation, and prediction functionality."""
+    if task == "classify":  # grayscale classification is not supported yet
+        return
+    grayscale_data = Path(TMP) / f"{Path(data).stem}-grayscale.yaml"
+    data = YAML.load(checks.check_file(data))
+    data["channels"] = 1  # add additional channels key for grayscale
+    YAML.save(grayscale_data, data)
+    # remove npy files in train/val splits if they exist, as they might be created by previous tests
+    for split in {"train", "val"}:
+        for npy_file in (Path(data["path"]) / data[split]).glob("*.npy"):
+            npy_file.unlink()
+
+    model = YOLO(model)
+    model.train(data=grayscale_data, epochs=1, imgsz=32, close_mosaic=1)
+    model.val(data=grayscale_data)
+    im = np.zeros((32, 32, 1), dtype=np.uint8)
+    model.predict(source=im, imgsz=32, save_txt=True, save_crop=True, augment=True)
+    export_model = model.export(format="onnx")
+
+    model = YOLO(export_model, task=task)
+    model.predict(source=im, imgsz=32)
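The core of the new grayscale test is converting any detection dataset to single-channel by adding a channels key to its YAML config. A hedged sketch of just that step (assuming YAML is the same helper from ultralytics.utils that the test imports):

```python
from ultralytics.utils import YAML  # assumption: the helper used by test_grayscale above

data = YAML.load("coco8.yaml")  # start from any detection dataset config
data["channels"] = 1  # declare single-channel input
YAML.save("coco8-grayscale.yaml", data)
```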
ultralytics/__init__.py
CHANGED
ultralytics/cfg/__init__.py
CHANGED
ultralytics/cfg/datasets/coco8-grayscale.yaml
ADDED
@@ -0,0 +1,103 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# COCO8-Grayscale dataset (first 8 images from COCO train2017) by Ultralytics
+# Documentation: https://docs.ultralytics.com/datasets/detect/coco8-grayscale/
+# Example usage: yolo train data=coco8-grayscale.yaml
+# parent
+# ├── ultralytics
+# └── datasets
+#     └── coco8-grayscale ← downloads here (1 MB)
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/coco8-grayscale # dataset root dir
+train: images/train # train images (relative to 'path') 4 images
+val: images/val # val images (relative to 'path') 4 images
+test: # test images (optional)
+
+channels: 1
+
+# Classes
+names:
+  0: person
+  1: bicycle
+  2: car
+  3: motorcycle
+  4: airplane
+  5: bus
+  6: train
+  7: truck
+  8: boat
+  9: traffic light
+  10: fire hydrant
+  11: stop sign
+  12: parking meter
+  13: bench
+  14: bird
+  15: cat
+  16: dog
+  17: horse
+  18: sheep
+  19: cow
+  20: elephant
+  21: bear
+  22: zebra
+  23: giraffe
+  24: backpack
+  25: umbrella
+  26: handbag
+  27: tie
+  28: suitcase
+  29: frisbee
+  30: skis
+  31: snowboard
+  32: sports ball
+  33: kite
+  34: baseball bat
+  35: baseball glove
+  36: skateboard
+  37: surfboard
+  38: tennis racket
+  39: bottle
+  40: wine glass
+  41: cup
+  42: fork
+  43: knife
+  44: spoon
+  45: bowl
+  46: banana
+  47: apple
+  48: sandwich
+  49: orange
+  50: broccoli
+  51: carrot
+  52: hot dog
+  53: pizza
+  54: donut
+  55: cake
+  56: chair
+  57: couch
+  58: potted plant
+  59: bed
+  60: dining table
+  61: toilet
+  62: tv
+  63: laptop
+  64: mouse
+  65: remote
+  66: keyboard
+  67: cell phone
+  68: microwave
+  69: oven
+  70: toaster
+  71: sink
+  72: refrigerator
+  73: book
+  74: clock
+  75: vase
+  76: scissors
+  77: teddy bear
+  78: hair drier
+  79: toothbrush
+
+# Download script/URL (optional)
+download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8-grayscale.zip
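Relative to the standard coco8.yaml, the only structural addition is channels: 1, which tells the data pipeline to build single-channel model inputs. Usage follows the file's own header comment:

```python
from ultralytics import YOLO

# Train on the grayscale dataset; the YAML's channels: 1 drives 1-channel input
model = YOLO("yolo11n.pt")
model.train(data="coco8-grayscale.yaml", epochs=1, imgsz=32)
```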
ultralytics/data/augment.py
CHANGED
@@ -3,7 +3,7 @@
 import math
 import random
 from copy import deepcopy
-from typing import List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import cv2
 import numpy as np
@@ -2416,7 +2416,7 @@ class RandomLoadText:
         self.padding = padding
         self.padding_value = padding_value
 
-    def __call__(self, labels:
+    def __call__(self, labels: Dict[str, Any]) -> Dict[str, Any]:
         """
         Randomly sample positive and negative texts and update class indices accordingly.
 
ultralytics/engine/model.py
CHANGED
@@ -634,10 +634,7 @@ class Model(torch.nn.Module):
         self.metrics = validator.metrics
         return validator.metrics
 
-    def benchmark(
-        self,
-        **kwargs: Any,
-    ):
+    def benchmark(self, data=None, format="", verbose=False, **kwargs: Any):
         """
         Benchmark the model across various export formats to evaluate performance.
 
@@ -647,14 +644,14 @@ class Model(torch.nn.Module):
         defaults, and any additional user-provided keyword arguments.
 
         Args:
+            data (str): Path to the dataset for benchmarking.
+            verbose (bool): Whether to print detailed benchmark information.
+            format (str): Export format name for specific benchmarking.
             **kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. Common options include:
-                - data (str): Path to the dataset for benchmarking.
                 - imgsz (int | List[int]): Image size for benchmarking.
                 - half (bool): Whether to use half-precision (FP16) mode.
                 - int8 (bool): Whether to use int8 precision mode.
                 - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
-                - verbose (bool): Whether to print detailed benchmark information.
-                - format (str): Export format name for specific benchmarking.
 
         Returns:
             (dict): A dictionary containing the results of the benchmarking process, including metrics for
@@ -671,17 +668,21 @@ class Model(torch.nn.Module):
         self._check_is_pytorch_model()
         from ultralytics.utils.benchmarks import benchmark
 
+        from .exporter import export_formats
+
         custom = {"verbose": False}  # method defaults
         args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
+        fmts = export_formats()
+        export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get(format, [])) - {"batch"}
+        export_kwargs = {k: v for k, v in args.items() if k in export_args}
         return benchmark(
             model=self,
-            data=
+            data=data,  # if no 'data' argument passed set data=None for default datasets
             imgsz=args["imgsz"],
-            half=args["half"],
-            int8=args["int8"],
             device=args["device"],
-            verbose=
-            format=
+            verbose=verbose,
+            format=format,
+            **export_kwargs,
         )
 
     def export(
@@ -1032,7 +1033,7 @@ class Model(torch.nn.Module):
         self.callbacks[event] = [callbacks.default_callbacks[event][0]]
 
     @staticmethod
-    def _reset_ckpt_args(args:
+    def _reset_ckpt_args(args: Dict[str, Any]) -> Dict[str, Any]:
         """
         Reset specific arguments when loading a PyTorch model checkpoint.
 
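With data, format, and verbose promoted to explicit parameters, export-specific kwargs are now forwarded only when valid for the chosen format. A hedged usage sketch (argument values illustrative):

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
# half is forwarded to the ONNX exporter because it is a valid argument for that format;
# kwargs the format does not support are filtered out by the new export_args logic
results = model.benchmark(data="coco8.yaml", format="onnx", verbose=True, half=True)
```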
ultralytics/engine/results.py
CHANGED
@@ -558,7 +558,7 @@ class Results(SimpleClass, DataExportMixin):
         )
         idx = (
             pred_boxes.id
-            if pred_boxes.
+            if pred_boxes.is_track and color_mode == "instance"
             else pred_boxes.cls
             if pred_boxes and color_mode == "class"
             else reversed(range(len(pred_masks)))
@@ -568,10 +568,10 @@ class Results(SimpleClass, DataExportMixin):
         # Plot Detect results
         if pred_boxes is not None and show_boxes:
             for i, d in enumerate(reversed(pred_boxes)):
-                c, d_conf, id = int(d.cls), float(d.conf) if conf else None,
+                c, d_conf, id = int(d.cls), float(d.conf) if conf else None, int(d.id.item()) if d.is_track else None
                 name = ("" if id is None else f"id:{id} ") + names[c]
                 label = (f"{name} {d_conf:.2f}" if conf else name) if labels else None
-                box = d.xyxyxyxy.
+                box = d.xyxyxyxy.squeeze() if is_obb else d.xyxy.squeeze()
                 annotator.box_label(
                     box,
                     label,
@@ -733,7 +733,7 @@ class Results(SimpleClass, DataExportMixin):
         elif boxes:
             # Detect/segment/pose
             for j, d in enumerate(boxes):
-                c, conf, id = int(d.cls), float(d.conf),
+                c, conf, id = int(d.cls), float(d.conf), int(d.id.item()) if d.is_track else None
                 line = (c, *(d.xyxyxyxyn.view(-1) if is_obb else d.xywhn.view(-1)))
                 if masks:
                     seg = masks[j].xyn[0].copy().reshape(-1)  # reversed mask.xyn, (n,2) to (n*2)
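The is_track guard replaces ad-hoc checks on a possibly missing id tensor. A sketch of consuming tracked results with the same pattern (video path illustrative):

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
for result in model.track("video.mp4", stream=True):
    for d in result.boxes:
        # id is only populated when a tracker is active, mirroring the plotting code above
        track_id = int(d.id.item()) if d.is_track else None
        print(int(d.cls), float(d.conf), track_id)
```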
ultralytics/engine/validator.py
CHANGED
ultralytics/models/nas/model.py
CHANGED
@@ -1,12 +1,4 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""
-YOLO-NAS model interface.
-
-Examples:
-    >>> from ultralytics import NAS
-    >>> model = NAS("yolo_nas_s")
-    >>> results = model.predict("ultralytics/assets/bus.jpg")
-"""
 
 from pathlib import Path
 from typing import Any, Dict
ultralytics/models/yolo/classify/val.py
CHANGED
@@ -106,14 +106,10 @@ class ClassificationValidator(BaseValidator):
         self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
         self.targets.append(batch["cls"].type(torch.int32).cpu())
 
-    def finalize_metrics(self, *args, **kwargs):
+    def finalize_metrics(self) -> None:
         """
         Finalize metrics including confusion matrix and processing speed.
 
-        Args:
-            *args (Any): Variable length argument list.
-            **kwargs (Any): Arbitrary keyword arguments.
-
         Notes:
             This method processes the accumulated predictions and targets to generate the confusion matrix,
             optionally plots it, and updates the metrics object with speed information.
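finalize_metrics() is invoked internally at the end of a validation run, so the signature cleanup is invisible to callers; a sketch of the public entry point (dataset name illustrative):

```python
from ultralytics import YOLO

# The validator calls finalize_metrics() itself; users only see the returned metrics
metrics = YOLO("yolo11n-cls.pt").val(data="imagenet10", imgsz=32)
print(metrics.top1, metrics.top5)  # assumption: ClassifyMetrics exposes top-1/top-5 accuracy
```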
ultralytics/models/yolo/detect/val.py
CHANGED
@@ -42,7 +42,7 @@ class DetectionValidator(BaseValidator):
     >>> validator()
     """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
         """
         Initialize detection validator with necessary variables and settings.
 
@@ -227,14 +227,13 @@ class DetectionValidator(BaseValidator):
             self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
         )
 
-    def finalize_metrics(self, *args, **kwargs):
-        """
-        Set final values for metrics speed and confusion matrix.
-
-        Args:
-            *args (Any): Variable length argument list.
-            **kwargs (Any): Arbitrary keyword arguments.
-        """
+    def finalize_metrics(self) -> None:
+        """Set final values for metrics speed and confusion matrix."""
+        if self.args.plots:
+            for normalize in True, False:
+                self.confusion_matrix.plot(
+                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
+                )
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
 
@@ -267,12 +266,6 @@ class DetectionValidator(BaseValidator):
                 pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
             )
 
-        if self.args.plots:
-            for normalize in True, False:
-                self.confusion_matrix.plot(
-                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
-                )
-
     def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
         """
         Return correct prediction matrix.
@@ -290,7 +283,7 @@ class DetectionValidator(BaseValidator):
         iou = box_iou(gt_bboxes, detections[:, :4])
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
-    def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None):
+    def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None) -> torch.utils.data.Dataset:
         """
         Build YOLO Dataset.
 
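Moving the confusion-matrix plotting from print_results() into finalize_metrics() ties plot generation to validation finishing rather than to console output. Callers are unaffected; a sketch:

```python
from ultralytics import YOLO

# With plots=True, confusion-matrix images are now written when metrics are finalized
metrics = YOLO("yolo11n.pt").val(data="coco8.yaml", imgsz=32, plots=True)
print(metrics.confusion_matrix.matrix.sum())  # assumption: raw matrix counts are exposed
```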
ultralytics/models/yolo/obb/val.py
CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 
@@ -40,7 +40,7 @@ class OBBValidator(DetectionValidator):
     >>> validator(model=args["model"])
     """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
         """
         Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
 
@@ -58,8 +58,13 @@ class OBBValidator(DetectionValidator):
         self.args.task = "obb"
         self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
 
-    def init_metrics(self, model):
-        """
+    def init_metrics(self, model: torch.nn.Module) -> None:
+        """
+        Initialize evaluation metrics for YOLO obb validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         super().init_metrics(model)
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
@@ -94,7 +99,7 @@ class OBBValidator(DetectionValidator):
 
         Args:
             si (int): Batch index to process.
-            batch (
+            batch (Dict[str, Any]): Dictionary containing batch data with keys:
                 - batch_idx: Tensor of batch indices
                 - cls: Tensor of class labels
                 - bboxes: Tensor of bounding boxes
@@ -103,7 +108,7 @@ class OBBValidator(DetectionValidator):
                 - ratio_pad: Ratio and padding information
 
         Returns:
-            (
+            (Dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
         """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
@@ -116,7 +121,7 @@ class OBBValidator(DetectionValidator):
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
-    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict) -> torch.Tensor:
+    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> torch.Tensor:
         """
         Prepare predictions by scaling bounding boxes to original image dimensions.
 
@@ -125,7 +130,7 @@ class OBBValidator(DetectionValidator):
 
         Args:
             pred (torch.Tensor): Prediction tensor containing bounding box coordinates and other information.
-            pbatch (
+            pbatch (Dict[str, Any]): Dictionary containing batch information with keys:
                 - imgsz (tuple): Model input image size.
                 - ori_shape (tuple): Original image shape.
                 - ratio_pad (tuple): Ratio and padding information for scaling.
@@ -139,13 +144,13 @@ class OBBValidator(DetectionValidator):
         )  # native-space pred
         return predn
 
-    def plot_predictions(self, batch: Dict, preds: List[torch.Tensor], ni: int):
+    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
         """
         Plot predicted bounding boxes on input images and save the result.
 
         Args:
-            batch (
-            preds (
+            batch (Dict[str, Any]): Batch data containing images, file paths, and other metadata.
+            preds (List[torch.Tensor]): List of prediction tensors for each image in the batch.
             ni (int): Batch index used for naming the output file.
 
         Examples:
@@ -163,7 +168,7 @@ class OBBValidator(DetectionValidator):
             on_plot=self.on_plot,
         )  # pred
 
-    def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]):
+    def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]) -> None:
         """
         Convert YOLO predictions to COCO JSON format with rotated bounding box information.
 
@@ -192,7 +197,9 @@ class OBBValidator(DetectionValidator):
                 }
             )
 
-    def save_one_txt(
+    def save_one_txt(
+        self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Union[Path, str]
+    ) -> None:
         """
         Save YOLO OBB detections to a text file in normalized coordinates.
 
@@ -200,7 +207,7 @@ class OBBValidator(DetectionValidator):
             predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
                 class predictions, and angles in format (x, y, w, h, conf, cls, angle).
             save_conf (bool): Whether to save confidence scores in the text file.
-            shape (
+            shape (Tuple[int, int]): Original image shape in format (height, width).
             file (Path | str): Output file path to save detections.
 
         Examples:
@@ -222,15 +229,15 @@ class OBBValidator(DetectionValidator):
                 obb=obb,
             ).save_txt(file, save_conf=save_conf)
 
-    def eval_json(self, stats: Dict) -> Dict:
+    def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
         """
         Evaluate YOLO output in JSON format and save predictions in DOTA format.
 
         Args:
-            stats (
+            stats (Dict[str, Any]): Performance statistics dictionary.
 
         Returns:
-            (
+            (Dict[str, Any]): Updated performance statistics.
         """
         if self.args.save_json and self.is_dota and len(self.jdict):
             import json