dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in the public registry.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
--- a/ultralytics/models/yolo/classify/val.py
+++ b/ultralytics/models/yolo/classify/val.py
@@ -6,20 +6,20 @@ from pathlib import Path
 from typing import Any
 
 import torch
+import torch.distributed as dist
 
 from ultralytics.data import ClassificationDataset, build_dataloader
 from ultralytics.engine.validator import BaseValidator
-from ultralytics.utils import LOGGER
+from ultralytics.utils import LOGGER, RANK
 from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
 from ultralytics.utils.plotting import plot_images
 
 
 class ClassificationValidator(BaseValidator):
-    """
-    A class extending the BaseValidator class for validation based on a classification model.
+    """A class extending the BaseValidator class for validation based on a classification model.
 
-    This validator handles the validation process for classification models, including metrics calculation,
-    confusion matrix generation, and visualization of results.
+    This validator handles the validation process for classification models, including metrics calculation, confusion
+    matrix generation, and visualization of results.
 
     Attributes:
         targets (list[torch.Tensor]): Ground truth class labels.
@@ -45,7 +45,7 @@ class ClassificationValidator(BaseValidator):
 
     Examples:
         >>> from ultralytics.models.yolo.classify import ClassificationValidator
-        >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
+        >>> args = dict(model="yolo26n-cls.pt", data="imagenet10")
        >>> validator = ClassificationValidator(args=args)
        >>> validator()
 
@@ -54,20 +54,13 @@ class ClassificationValidator(BaseValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize ClassificationValidator with dataloader, save directory, and other parameters.
+        """Initialize ClassificationValidator with dataloader, save directory, and other parameters.
 
         Args:
-            dataloader (torch.utils.data.DataLoader, optional):
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
             save_dir (str | Path, optional): Directory to save results.
             args (dict, optional): Arguments containing model and validation configuration.
             _callbacks (list, optional): List of callback functions to be called during validation.
-
-        Examples:
-            >>> from ultralytics.models.yolo.classify import ClassificationValidator
-            >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
-            >>> validator = ClassificationValidator(args=args)
-            >>> validator()
         """
         super().__init__(dataloader, save_dir, args, _callbacks)
         self.targets = None
@@ -95,8 +88,7 @@ class ClassificationValidator(BaseValidator):
         return batch
 
     def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
-        """
-        Update running metrics with model predictions and batch targets.
+        """Update running metrics with model predictions and batch targets.
 
         Args:
             preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
@@ -111,12 +103,7 @@ class ClassificationValidator(BaseValidator):
         self.targets.append(batch["cls"].type(torch.int32).cpu())
 
     def finalize_metrics(self) -> None:
-        """
-        Finalize metrics including confusion matrix and processing speed.
-
-        Notes:
-            This method processes the accumulated predictions and targets to generate the confusion matrix,
-            optionally plots it, and updates the metrics object with speed information.
+        """Finalize metrics including confusion matrix and processing speed.
 
         Examples:
             >>> validator = ClassificationValidator()
@@ -124,6 +111,10 @@ class ClassificationValidator(BaseValidator):
             >>> validator.targets = [torch.tensor([0])]  # Ground truth class
             >>> validator.finalize_metrics()
             >>> print(validator.metrics.confusion_matrix)  # Access the confusion matrix
+
+        Notes:
+            This method processes the accumulated predictions and targets to generate the confusion matrix,
+            optionally plots it, and updates the metrics object with speed information.
         """
         self.confusion_matrix.process_cls_preds(self.pred, self.targets)
         if self.args.plots:
@@ -142,13 +133,25 @@ class ClassificationValidator(BaseValidator):
         self.metrics.process(self.targets, self.pred)
         return self.metrics.results_dict
 
+    def gather_stats(self) -> None:
+        """Gather stats from all GPUs."""
+        if RANK == 0:
+            gathered_preds = [None] * dist.get_world_size()
+            gathered_targets = [None] * dist.get_world_size()
+            dist.gather_object(self.pred, gathered_preds, dst=0)
+            dist.gather_object(self.targets, gathered_targets, dst=0)
+            self.pred = [pred for rank in gathered_preds for pred in rank]
+            self.targets = [targets for rank in gathered_targets for targets in rank]
+        elif RANK > 0:
+            dist.gather_object(self.pred, None, dst=0)
+            dist.gather_object(self.targets, None, dst=0)
+
     def build_dataset(self, img_path: str) -> ClassificationDataset:
         """Create a ClassificationDataset instance for validation."""
         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
 
     def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
-        """
-        Build and return a data loader for classification validation.
+        """Build and return a data loader for classification validation.
 
         Args:
             dataset_path (str | Path): Path to the dataset directory.
@@ -166,8 +169,7 @@ class ClassificationValidator(BaseValidator):
         LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
 
     def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
-        """
-        Plot validation image samples with their ground truth labels.
+        """Plot validation image samples with their ground truth labels.
 
         Args:
             batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
@@ -187,8 +189,7 @@ class ClassificationValidator(BaseValidator):
         )
 
     def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
-        """
-        Plot images with their predicted class labels and save the visualization.
+        """Plot images with their predicted class labels and save the visualization.
 
         Args:
             batch (dict[str, Any]): Batch data containing images and other information.
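The main functional addition in this file is the new `gather_stats` hook, which merges per-rank predictions and targets on rank 0 so DDP validation can use every GPU. Below is a minimal standalone sketch of the same `torch.distributed.gather_object` pattern; the file name and values are illustrative, not part of the package.

```python
# gather_demo.py -- launch with: torchrun --nproc_per_node=2 gather_demo.py
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="gloo")  # gloo suffices for CPU-side object gathering
    rank = dist.get_rank()

    # Each rank accumulates its own shard of predictions, as the validator does per GPU.
    local_preds = [torch.tensor([rank * 10 + i]) for i in range(3)]

    if rank == 0:
        gathered = [None] * dist.get_world_size()
        dist.gather_object(local_preds, gathered, dst=0)
        # Flatten per-rank lists into one list, mirroring the diff's list comprehensions.
        merged = [p for shard in gathered for p in shard]
        print(f"rank 0 merged {len(merged)} predictions: {merged}")
    else:
        dist.gather_object(local_preds, None, dst=0)  # non-zero ranks only send

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Only rank 0 supplies a destination list; the other ranks pass `None` and merely send, matching the `RANK == 0` / `RANK > 0` branches in the diff.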
--- a/ultralytics/models/yolo/detect/predict.py
+++ b/ultralytics/models/yolo/detect/predict.py
@@ -6,8 +6,7 @@ from ultralytics.utils import nms, ops
 
 
 class DetectionPredictor(BasePredictor):
-    """
-    A class extending the BasePredictor class for prediction based on a detection model.
+    """A class extending the BasePredictor class for prediction based on a detection model.
 
     This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
     with bounding boxes and class predictions.
@@ -26,14 +25,13 @@ class DetectionPredictor(BasePredictor):
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.yolo.detect import DetectionPredictor
-        >>> args = dict(model="yolo11n.pt", source=ASSETS)
+        >>> args = dict(model="yolo26n.pt", source=ASSETS)
        >>> predictor = DetectionPredictor(overrides=args)
        >>> predictor.predict_cli()
     """
 
     def postprocess(self, preds, img, orig_imgs, **kwargs):
-        """
-        Post-process predictions and return a list of Results objects.
+        """Post-process predictions and return a list of Results objects.
 
         This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
         further analysis.
@@ -48,7 +46,7 @@ class DetectionPredictor(BasePredictor):
             (list): List of Results objects containing the post-processed predictions.
 
         Examples:
-            >>> predictor = DetectionPredictor(overrides=dict(model="yolo11n.pt"))
+            >>> predictor = DetectionPredictor(overrides=dict(model="yolo26n.pt"))
             >>> results = predictor.predict("path/to/image.jpg")
             >>> processed_results = predictor.postprocess(preds, img, orig_imgs)
         """
@@ -67,7 +65,7 @@ class DetectionPredictor(BasePredictor):
         )
 
         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]
 
         if save_feats:
             obj_feats = self.get_obj_feats(self._feats, preds[1])
@@ -81,7 +79,8 @@ class DetectionPredictor(BasePredictor):
 
         return results
 
-    def get_obj_feats(self, feat_maps, idxs):
+    @staticmethod
+    def get_obj_feats(feat_maps, idxs):
         """Extract object features from the feature maps."""
         import torch
 
@@ -92,8 +91,7 @@ class DetectionPredictor(BasePredictor):
         return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)]  # for each img in batch
 
     def construct_results(self, preds, img, orig_imgs):
-        """
-        Construct a list of Results objects from model predictions.
+        """Construct a list of Results objects from model predictions.
 
         Args:
             preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
@@ -109,8 +107,7 @@ class DetectionPredictor(BasePredictor):
         ]
 
     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct a single Results object from one image prediction.
+        """Construct a single Results object from one image prediction.
 
         Args:
             pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
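Beyond the docstring reflow and the `@staticmethod` conversion, the behavioral change above is the `[..., ::-1]` appended to `convert_torch2numpy_batch`, which reverses the channel axis of the NumPy image batch (an RGB/BGR swap for tensor inputs). A quick sketch of what that slice does, assuming an NHWC batch:

```python
import numpy as np

# One 2x2 RGB image in NHWC layout; values chosen so channels are distinguishable.
batch = np.zeros((1, 2, 2, 3), dtype=np.uint8)
batch[..., 0] = 255  # fill the R channel

flipped = batch[..., ::-1]  # reverse the last (channel) axis: RGB -> BGR
assert flipped[0, 0, 0].tolist() == [0, 0, 255]  # red now sits in the last channel slot
print(flipped[0, 0, 0])  # [  0   0 255]
```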
--- a/ultralytics/models/yolo/detect/train.py
+++ b/ultralytics/models/yolo/detect/train.py
@@ -22,11 +22,10 @@ from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model
 
 
 class DetectionTrainer(BaseTrainer):
-    """
-    A class extending the BaseTrainer class for training based on a detection model.
+    """A class extending the BaseTrainer class for training based on a detection model.
 
-    This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
-    for object detection including dataset building, data loading, preprocessing, and model configuration.
+    This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for
+    object detection including dataset building, data loading, preprocessing, and model configuration.
 
     Attributes:
         model (DetectionModel): The YOLO detection model being trained.
@@ -48,14 +47,13 @@ class DetectionTrainer(BaseTrainer):
 
     Examples:
         >>> from ultralytics.models.yolo.detect import DetectionTrainer
-        >>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
+        >>> args = dict(model="yolo26n.pt", data="coco8.yaml", epochs=3)
        >>> trainer = DetectionTrainer(overrides=args)
        >>> trainer.train()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a DetectionTrainer object for training YOLO object detection model training.
+        """Initialize a DetectionTrainer object for training YOLO object detection models.
 
         Args:
             cfg (dict, optional): Default configuration dictionary containing training parameters.
@@ -65,8 +63,7 @@ class DetectionTrainer(BaseTrainer):
         super().__init__(cfg, overrides, _callbacks)
 
     def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -80,8 +77,7 @@ class DetectionTrainer(BaseTrainer):
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
 
     def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
-        """
-        Construct and return dataloader for the specified mode.
+        """Construct and return dataloader for the specified mode.
 
         Args:
             dataset_path (str): Path to the dataset.
@@ -109,8 +105,7 @@ class DetectionTrainer(BaseTrainer):
         )
 
     def preprocess_batch(self, batch: dict) -> dict:
-        """
-        Preprocess a batch of images by scaling and converting to float.
+        """Preprocess a batch of images by scaling and converting to float.
 
         Args:
             batch (dict): Dictionary containing batch data with 'img' tensor.
@@ -122,10 +117,13 @@ class DetectionTrainer(BaseTrainer):
             if isinstance(v, torch.Tensor):
                 batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
         batch["img"] = batch["img"].float() / 255
-        if self.args.multi_scale:
+        if self.args.multi_scale > 0.0:
             imgs = batch["img"]
             sz = (
-                random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
+                random.randrange(
+                    int(self.args.imgsz * (1.0 - self.args.multi_scale)),
+                    int(self.args.imgsz * (1.0 + self.args.multi_scale) + self.stride),
+                )
                 // self.stride
                 * self.stride
             )  # size
@@ -150,8 +148,7 @@ class DetectionTrainer(BaseTrainer):
         # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
 
     def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
-        """
-        Return a YOLO detection model.
+        """Return a YOLO detection model.
 
         Args:
             cfg (str, optional): Path to model configuration file.
@@ -174,8 +171,7 @@ class DetectionTrainer(BaseTrainer):
         )
 
     def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
-        """
-        Return a loss dict with labeled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.
 
         Args:
             loss_items (list[float], optional): List of loss values.
@@ -202,8 +198,7 @@ class DetectionTrainer(BaseTrainer):
         )
 
     def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
-        """
-        Plot training samples with their annotations.
+        """Plot training samples with their annotations.
 
         Args:
             batch (dict[str, Any]): Dictionary containing batch data.
@@ -223,8 +218,7 @@ class DetectionTrainer(BaseTrainer):
         plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
 
     def auto_batch(self):
-        """
-        Get optimal batch size by calculating memory occupation of model.
+        """Get optimal batch size by calculating memory occupation of model.
 
         Returns:
             (int): Optimal batch size.
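`multi_scale` is now treated as a float factor rather than a boolean gate: training sizes are sampled from roughly `imgsz * (1 - multi_scale)` to `imgsz * (1 + multi_scale)` and snapped down to a multiple of the model stride, instead of the previous hard-coded 0.5-1.5x range. A sketch of the arithmetic with hypothetical values, mirroring the diff:

```python
import random

imgsz, stride, multi_scale = 640, 32, 0.5  # example values; multi_scale is now a float factor

# Sample a size in [imgsz*(1-ms), imgsz*(1+ms)+stride), then snap down to a stride multiple,
# as the updated preprocess_batch does.
sz = (
    random.randrange(int(imgsz * (1.0 - multi_scale)), int(imgsz * (1.0 + multi_scale) + stride))
    // stride
    * stride
)
assert sz % stride == 0 and 320 <= sz <= 960
print(sz)
```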
--- a/ultralytics/models/yolo/detect/val.py
+++ b/ultralytics/models/yolo/detect/val.py
@@ -8,18 +8,18 @@ from typing import Any
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from ultralytics.data import build_dataloader, build_yolo_dataset, converter
 from ultralytics.engine.validator import BaseValidator
-from ultralytics.utils import LOGGER, nms, ops
+from ultralytics.utils import LOGGER, RANK, nms, ops
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
 from ultralytics.utils.plotting import plot_images
 
 
 class DetectionValidator(BaseValidator):
-    """
-    A class extending the BaseValidator class for validation based on a detection model.
+    """A class extending the BaseValidator class for validation based on a detection model.
 
     This class implements validation functionality specific to object detection tasks, including metrics calculation,
     prediction processing, and visualization of results.
@@ -37,17 +37,16 @@ class DetectionValidator(BaseValidator):
 
     Examples:
         >>> from ultralytics.models.yolo.detect import DetectionValidator
-        >>> args = dict(model="yolo11n.pt", data="coco8.yaml")
+        >>> args = dict(model="yolo26n.pt", data="coco8.yaml")
        >>> validator = DetectionValidator(args=args)
        >>> validator()
     """
 
     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize detection validator with necessary variables and settings.
+        """Initialize detection validator with necessary variables and settings.
 
         Args:
-            dataloader (torch.utils.data.DataLoader, optional):
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
             save_dir (Path, optional): Directory to save results.
             args (dict[str, Any], optional): Arguments for the validator.
             _callbacks (list[Any], optional): List of callback functions.
@@ -62,8 +61,7 @@ class DetectionValidator(BaseValidator):
         self.metrics = DetMetrics()
 
     def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Preprocess batch of images for YOLO validation.
+        """Preprocess batch of images for YOLO validation.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -78,8 +76,7 @@ class DetectionValidator(BaseValidator):
         return batch
 
     def init_metrics(self, model: torch.nn.Module) -> None:
-        """
-        Initialize evaluation metrics for YOLO detection validation.
+        """Initialize evaluation metrics for YOLO detection validation.
 
         Args:
             model (torch.nn.Module): Model to validate.
@@ -106,15 +103,14 @@ class DetectionValidator(BaseValidator):
         return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
 
     def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
-        """
-        Apply Non-maximum suppression to prediction outputs.
+        """Apply Non-maximum suppression to prediction outputs.
 
         Args:
             preds (torch.Tensor): Raw predictions from the model.
 
         Returns:
-            (list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
-                'bboxes', 'conf', 'cls', and 'extra' tensors.
+            (list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains 'bboxes', 'conf',
+                'cls', and 'extra' tensors.
         """
         outputs = nms.non_max_suppression(
             preds,
@@ -130,8 +126,7 @@ class DetectionValidator(BaseValidator):
         return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]
 
     def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Prepare a batch of images and annotations for validation.
+        """Prepare a batch of images and annotations for validation.
 
         Args:
             si (int): Batch index.
@@ -158,8 +153,7 @@ class DetectionValidator(BaseValidator):
         }
 
     def _prepare_pred(self, pred: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-        """
-        Prepare predictions for evaluation against ground truth.
+        """Prepare predictions for evaluation against ground truth.
 
         Args:
             pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
@@ -172,8 +166,7 @@ class DetectionValidator(BaseValidator):
         return pred
 
     def update_metrics(self, preds: list[dict[str, torch.Tensor]], batch: dict[str, Any]) -> None:
-        """
-        Update metrics with new predictions and ground truth.
+        """Update metrics with new predictions and ground truth.
 
         Args:
             preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
@@ -226,9 +219,30 @@ class DetectionValidator(BaseValidator):
         self.metrics.confusion_matrix = self.confusion_matrix
         self.metrics.save_dir = self.save_dir
 
+    def gather_stats(self) -> None:
+        """Gather stats from all GPUs."""
+        if RANK == 0:
+            gathered_stats = [None] * dist.get_world_size()
+            dist.gather_object(self.metrics.stats, gathered_stats, dst=0)
+            merged_stats = {key: [] for key in self.metrics.stats.keys()}
+            for stats_dict in gathered_stats:
+                for key in merged_stats:
+                    merged_stats[key].extend(stats_dict[key])
+            gathered_jdict = [None] * dist.get_world_size()
+            dist.gather_object(self.jdict, gathered_jdict, dst=0)
+            self.jdict = []
+            for jdict in gathered_jdict:
+                self.jdict.extend(jdict)
+            self.metrics.stats = merged_stats
+            self.seen = len(self.dataloader.dataset)  # total image count from dataset
+        elif RANK > 0:
+            dist.gather_object(self.metrics.stats, None, dst=0)
+            dist.gather_object(self.jdict, None, dst=0)
+            self.jdict = []
+            self.metrics.clear_stats()
+
     def get_stats(self) -> dict[str, Any]:
-        """
-        Calculate and return metrics statistics.
+        """Calculate and return metrics statistics.
 
         Returns:
             (dict[str, Any]): Dictionary containing metrics results.
@@ -242,7 +256,7 @@ class DetectionValidator(BaseValidator):
         pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
         LOGGER.info(pf % ("all", self.seen, self.metrics.nt_per_class.sum(), *self.metrics.mean_results()))
         if self.metrics.nt_per_class.sum() == 0:
-            LOGGER.warning(f"no labels found in {self.args.task} set, can not compute metrics without labels")
+            LOGGER.warning(f"no labels found in {self.args.task} set, cannot compute metrics without labels")
 
         # Print results per class
         if self.args.verbose and not self.training and self.nc > 1 and len(self.metrics.stats):
@@ -258,15 +272,15 @@ class DetectionValidator(BaseValidator):
         )
 
     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
-        """
-        Return correct prediction matrix.
+        """Return correct prediction matrix.
 
         Args:
             preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
             batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
 
         Returns:
-            (dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
+            (dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for
+                10 IoU levels.
         """
         if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
             return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
@@ -274,8 +288,7 @@ class DetectionValidator(BaseValidator):
         return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
 
     def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None) -> torch.utils.data.Dataset:
-        """
-        Build YOLO Dataset.
+        """Build YOLO Dataset.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -288,24 +301,28 @@ class DetectionValidator(BaseValidator):
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
 
     def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
-        """
-        Construct and return dataloader.
+        """Construct and return dataloader.
 
         Args:
             dataset_path (str): Path to the dataset.
             batch_size (int): Size of each batch.
 
         Returns:
-            (torch.utils.data.DataLoader):
+            (torch.utils.data.DataLoader): DataLoader for validation.
         """
         dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
         return build_dataloader(
-            dataset, batch_size, self.args.workers, shuffle=False, rank=-1
+            dataset,
+            batch_size,
+            self.args.workers,
+            shuffle=False,
+            rank=-1,
+            drop_last=self.args.compile,
+            pin_memory=self.training,
         )
 
     def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
-        """
-        Plot validation image samples.
+        """Plot validation image samples.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -322,8 +339,7 @@ class DetectionValidator(BaseValidator):
     def plot_predictions(
         self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: int | None = None
     ) -> None:
-        """
-        Plot predicted bounding boxes on input images and save the result.
+        """Plot predicted bounding boxes on input images and save the result.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -331,14 +347,14 @@ class DetectionValidator(BaseValidator):
             ni (int): Batch index.
             max_det (Optional[int]): Maximum number of detections to plot.
         """
-
+        if not preds:
+            return
         for i, pred in enumerate(preds):
             pred["batch_idx"] = torch.ones_like(pred["conf"]) * i  # add batch index to predictions
         keys = preds[0].keys()
         max_det = max_det or self.args.max_det
         batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys}
-        #
-        batched_preds["bboxes"][:, :4] = ops.xyxy2xywh(batched_preds["bboxes"][:, :4])  # convert to xywh format
+        batched_preds["bboxes"] = ops.xyxy2xywh(batched_preds["bboxes"])  # convert to xywh format
         plot_images(
             images=batch["img"],
             labels=batched_preds,
@@ -349,8 +365,7 @@ class DetectionValidator(BaseValidator):
         )  # pred
 
     def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
-        """
-        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format.
 
         Args:
             predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
@@ -368,12 +383,11 @@ class DetectionValidator(BaseValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Serialize YOLO predictions to COCO json format.
+        """Serialize YOLO predictions to COCO json format.
 
         Args:
-            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
-                with bounding box coordinates, confidence scores, and class predictions.
+            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                bounding box coordinates, confidence scores, and class predictions.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
 
         Examples:
@@ -414,8 +428,7 @@ class DetectionValidator(BaseValidator):
         }
 
     def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
-        """
-        Evaluate YOLO output in JSON format and return performance statistics.
+        """Evaluate YOLO output in JSON format and return performance statistics.
 
         Args:
             stats (dict[str, Any]): Current statistics dictionary.
@@ -439,21 +452,20 @@ class DetectionValidator(BaseValidator):
         iou_types: str | list[str] = "bbox",
         suffix: str | list[str] = "Box",
     ) -> dict[str, Any]:
-        """
-        Evaluate COCO/LVIS metrics using faster-coco-eval library.
+        """Evaluate COCO/LVIS metrics using faster-coco-eval library.
 
-        Performs evaluation using the faster-coco-eval library to compute mAP metrics
-        for object detection. Updates the provided stats dictionary with computed metrics
-        including mAP50, mAP50-95, and LVIS-specific metrics if applicable.
+        Performs evaluation using the faster-coco-eval library to compute mAP metrics for object detection. Updates the
+        provided stats dictionary with computed metrics including mAP50, mAP50-95, and LVIS-specific metrics if
+        applicable.
 
         Args:
             stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
-            pred_json (str | Path): Path to JSON file containing predictions in COCO format.
-            anno_json (str | Path): Path to JSON file containing ground truth annotations in COCO format.
-            iou_types (str | list[str]): IoU type(s) for evaluation. Can be single string or list of strings.
-                Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
-            suffix (str | list[str]): Suffix to append to metric names in stats dictionary. Should correspond
-                to iou_types if multiple types provided. Defaults to "Box".
+            pred_json (str | Path): Path to JSON file containing predictions in COCO format.
+            anno_json (str | Path): Path to JSON file containing ground truth annotations in COCO format.
+            iou_types (str | list[str]): IoU type(s) for evaluation. Can be single string or list of strings. Common
+                values include "bbox", "segm", "keypoints". Defaults to "bbox".
+            suffix (str | list[str]): Suffix to append to metric names in stats dictionary. Should correspond to
+                iou_types if multiple types provided. Defaults to "Box".
 
         Returns:
             (dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
@@ -482,6 +494,12 @@ class DetectionValidator(BaseValidator):
             # update mAP50-95 and mAP50
             stats[f"metrics/mAP50({suffix[i][0]})"] = val.stats_as_dict["AP_50"]
             stats[f"metrics/mAP50-95({suffix[i][0]})"] = val.stats_as_dict["AP_all"]
+            # record mAP for small, medium, large objects as well
+            stats["metrics/mAP_small(B)"] = val.stats_as_dict["AP_small"]
+            stats["metrics/mAP_medium(B)"] = val.stats_as_dict["AP_medium"]
+            stats["metrics/mAP_large(B)"] = val.stats_as_dict["AP_large"]
+            # update fitness
+            stats["fitness"] = 0.9 * val.stats_as_dict["AP_all"] + 0.1 * val.stats_as_dict["AP_50"]
 
             if self.is_lvis:
                 stats[f"metrics/APr({suffix[i][0]})"] = val.stats_as_dict["APr"]