dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/detect/val.py
CHANGED
@@ -8,18 +8,18 @@ from typing import Any
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from ultralytics.data import build_dataloader, build_yolo_dataset, converter
 from ultralytics.engine.validator import BaseValidator
-from ultralytics.utils import LOGGER, nms, ops
+from ultralytics.utils import LOGGER, RANK, nms, ops
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
 from ultralytics.utils.plotting import plot_images
 
 
 class DetectionValidator(BaseValidator):
-    """
-    A class extending the BaseValidator class for validation based on a detection model.
+    """A class extending the BaseValidator class for validation based on a detection model.
 
     This class implements validation functionality specific to object detection tasks, including metrics calculation,
     prediction processing, and visualization of results.
@@ -43,11 +43,10 @@ class DetectionValidator(BaseValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize detection validator with necessary variables and settings.
+        """Initialize detection validator with necessary variables and settings.
 
         Args:
-            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
             save_dir (Path, optional): Directory to save results.
             args (dict[str, Any], optional): Arguments for the validator.
             _callbacks (list[Any], optional): List of callback functions.
@@ -62,8 +61,7 @@ class DetectionValidator(BaseValidator):
         self.metrics = DetMetrics()
 
     def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Preprocess batch of images for YOLO validation.
+        """Preprocess batch of images for YOLO validation.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -73,13 +71,12 @@ class DetectionValidator(BaseValidator):
         """
         for k, v in batch.items():
            if isinstance(v, torch.Tensor):
-                batch[k] = v.to(self.device, non_blocking=True)
+                batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
         batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
         return batch
 
     def init_metrics(self, model: torch.nn.Module) -> None:
-        """
-        Initialize evaluation metrics for YOLO detection validation.
+        """Initialize evaluation metrics for YOLO detection validation.
 
         Args:
             model (torch.nn.Module): Model to validate.
@@ -106,15 +103,14 @@ class DetectionValidator(BaseValidator):
         return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
 
     def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
-        """
-        Apply Non-maximum suppression to prediction outputs.
+        """Apply Non-maximum suppression to prediction outputs.
 
         Args:
             preds (torch.Tensor): Raw predictions from the model.
 
         Returns:
-            (list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
-                'bboxes', 'conf', 'cls', and 'extra' tensors.
+            (list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains 'bboxes', 'conf',
+                'cls', and 'extra' tensors.
         """
         outputs = nms.non_max_suppression(
             preds,
@@ -130,8 +126,7 @@ class DetectionValidator(BaseValidator):
         return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]
 
     def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Prepare a batch of images and annotations for validation.
+        """Prepare a batch of images and annotations for validation.
 
         Args:
             si (int): Batch index.
@@ -146,7 +141,7 @@ class DetectionValidator(BaseValidator):
         ori_shape = batch["ori_shape"][si]
         imgsz = batch["img"].shape[2:]
         ratio_pad = batch["ratio_pad"][si]
-        if len(cls):
+        if cls.shape[0]:
            bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
         return {
             "cls": cls,
@@ -158,8 +153,7 @@ class DetectionValidator(BaseValidator):
         }
 
     def _prepare_pred(self, pred: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-        """
-        Prepare predictions for evaluation against ground truth.
+        """Prepare predictions for evaluation against ground truth.
 
         Args:
             pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
@@ -172,8 +166,7 @@ class DetectionValidator(BaseValidator):
         return pred
 
     def update_metrics(self, preds: list[dict[str, torch.Tensor]], batch: dict[str, Any]) -> None:
-        """
-        Update metrics with new predictions and ground truth.
+        """Update metrics with new predictions and ground truth.
 
         Args:
             preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
@@ -185,7 +178,7 @@ class DetectionValidator(BaseValidator):
             predn = self._prepare_pred(pred)
 
             cls = pbatch["cls"].cpu().numpy()
-            no_pred = len(predn["cls"]) == 0
+            no_pred = predn["cls"].shape[0] == 0
             self.metrics.update_stats(
                 {
                     **self._process_batch(predn, pbatch),
@@ -226,9 +219,30 @@ class DetectionValidator(BaseValidator):
         self.metrics.confusion_matrix = self.confusion_matrix
         self.metrics.save_dir = self.save_dir
 
+    def gather_stats(self) -> None:
+        """Gather stats from all GPUs."""
+        if RANK == 0:
+            gathered_stats = [None] * dist.get_world_size()
+            dist.gather_object(self.metrics.stats, gathered_stats, dst=0)
+            merged_stats = {key: [] for key in self.metrics.stats.keys()}
+            for stats_dict in gathered_stats:
+                for key in merged_stats:
+                    merged_stats[key].extend(stats_dict[key])
+            gathered_jdict = [None] * dist.get_world_size()
+            dist.gather_object(self.jdict, gathered_jdict, dst=0)
+            self.jdict = []
+            for jdict in gathered_jdict:
+                self.jdict.extend(jdict)
+            self.metrics.stats = merged_stats
+            self.seen = len(self.dataloader.dataset)  # total image count from dataset
+        elif RANK > 0:
+            dist.gather_object(self.metrics.stats, None, dst=0)
+            dist.gather_object(self.jdict, None, dst=0)
+            self.jdict = []
+            self.metrics.clear_stats()
+
     def get_stats(self) -> dict[str, Any]:
-        """
-        Calculate and return metrics statistics.
+        """Calculate and return metrics statistics.
 
         Returns:
             (dict[str, Any]): Dictionary containing metrics results.
@@ -242,7 +256,7 @@ class DetectionValidator(BaseValidator):
         pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
         LOGGER.info(pf % ("all", self.seen, self.metrics.nt_per_class.sum(), *self.metrics.mean_results()))
         if self.metrics.nt_per_class.sum() == 0:
-            LOGGER.warning(f"no labels found in {self.args.task} set, can not compute metrics without labels")
+            LOGGER.warning(f"no labels found in {self.args.task} set, cannot compute metrics without labels")
 
         # Print results per class
         if self.args.verbose and not self.training and self.nc > 1 and len(self.metrics.stats):
@@ -258,24 +272,23 @@ class DetectionValidator(BaseValidator):
             )
 
     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
-        """
-        Return correct prediction matrix.
+        """Return correct prediction matrix.
 
         Args:
             preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
             batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
 
         Returns:
-            (dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
+            (dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for
+                10 IoU levels.
         """
-        if len(batch["cls"]) == 0 or len(preds["cls"]) == 0:
-            return {"tp": np.zeros((len(preds["cls"]), self.niou), dtype=bool)}
+        if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
+            return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
         iou = box_iou(batch["bboxes"], preds["bboxes"])
         return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
 
     def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None) -> torch.utils.data.Dataset:
-        """
-        Build YOLO Dataset.
+        """Build YOLO Dataset.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -288,24 +301,28 @@ class DetectionValidator(BaseValidator):
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
 
     def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
-        """
-        Construct and return dataloader.
+        """Construct and return dataloader.
 
         Args:
             dataset_path (str): Path to the dataset.
             batch_size (int): Size of each batch.
 
         Returns:
-            (torch.utils.data.DataLoader): Dataloader for validation.
+            (torch.utils.data.DataLoader): DataLoader for validation.
         """
         dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
         return build_dataloader(
-            dataset, batch_size, self.args.workers, shuffle=False, rank=-1
+            dataset,
+            batch_size,
+            self.args.workers,
+            shuffle=False,
+            rank=-1,
+            drop_last=self.args.compile,
+            pin_memory=self.training,
         )
 
     def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
-        """
-        Plot validation image samples.
+        """Plot validation image samples.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -322,8 +339,7 @@ class DetectionValidator(BaseValidator):
     def plot_predictions(
         self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: int | None = None
     ) -> None:
-        """
-        Plot predicted bounding boxes on input images and save the result.
+        """Plot predicted bounding boxes on input images and save the result.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -331,14 +347,14 @@ class DetectionValidator(BaseValidator):
             ni (int): Batch index.
             max_det (Optional[int]): Maximum number of detections to plot.
         """
-
+        if not preds:
+            return
         for i, pred in enumerate(preds):
             pred["batch_idx"] = torch.ones_like(pred["conf"]) * i  # add batch index to predictions
         keys = preds[0].keys()
         max_det = max_det or self.args.max_det
         batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys}
-        #
-        batched_preds["bboxes"][:, :4] = ops.xyxy2xywh(batched_preds["bboxes"][:, :4])  # convert to xywh format
+        batched_preds["bboxes"] = ops.xyxy2xywh(batched_preds["bboxes"])  # convert to xywh format
         plot_images(
             images=batch["img"],
             labels=batched_preds,
@@ -349,8 +365,7 @@ class DetectionValidator(BaseValidator):
         )  # pred
 
     def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
-        """
-        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format.
 
         Args:
             predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
@@ -368,12 +383,11 @@ class DetectionValidator(BaseValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Serialize YOLO predictions to COCO json format.
+        """Serialize YOLO predictions to COCO json format.
 
         Args:
-            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
-                with bounding box coordinates, confidence scores, and class predictions.
+            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                bounding box coordinates, confidence scores, and class predictions.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
 
         Examples:
@@ -414,8 +428,7 @@ class DetectionValidator(BaseValidator):
         }
 
     def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
-        """
-        Evaluate YOLO output in JSON format and return performance statistics.
+        """Evaluate YOLO output in JSON format and return performance statistics.
 
         Args:
             stats (dict[str, Any]): Current statistics dictionary.
@@ -439,21 +452,20 @@ class DetectionValidator(BaseValidator):
         iou_types: str | list[str] = "bbox",
         suffix: str | list[str] = "Box",
     ) -> dict[str, Any]:
-        """
-        Evaluate COCO/LVIS metrics using faster-coco-eval library.
+        """Evaluate COCO/LVIS metrics using faster-coco-eval library.
 
-        Performs evaluation using the faster-coco-eval library to compute mAP metrics
-        for object detection. Updates the provided stats dictionary with computed metrics including
-        mAP50, mAP50-95, and LVIS-specific metrics if applicable.
+        Performs evaluation using the faster-coco-eval library to compute mAP metrics for object detection. Updates the
+        provided stats dictionary with computed metrics including mAP50, mAP50-95, and LVIS-specific metrics if
+        applicable.
 
         Args:
             stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
-            pred_json (str | Path, optional): Path to JSON file containing predictions in COCO format.
-            anno_json (str | Path, optional): Path to JSON file containing ground truth annotations in COCO format.
-            iou_types (str | list[str], optional): IoU type(s) for evaluation. Can be single string or list of
-                strings. Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
-            suffix (str | list[str], optional): Suffix to append to metric names in stats dictionary. Should
-                correspond to iou_types if multiple types provided. Defaults to "Box".
+            pred_json (str | Path): Path to JSON file containing predictions in COCO format.
+            anno_json (str | Path): Path to JSON file containing ground truth annotations in COCO format.
+            iou_types (str | list[str]): IoU type(s) for evaluation. Can be single string or list of strings. Common
+                values include "bbox", "segm", "keypoints". Defaults to "bbox".
+            suffix (str | list[str]): Suffix to append to metric names in stats dictionary. Should correspond to
+                iou_types if multiple types provided. Defaults to "Box".
 
         Returns:
             (dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
ultralytics/models/yolo/model.py
CHANGED
@@ -24,8 +24,7 @@ from ultralytics.utils import ROOT, YAML
 
 
 class YOLO(Model):
-    """
-    YOLO (You Only Look Once) object detection model.
+    """YOLO (You Only Look Once) object detection model.
 
     This class provides a unified interface for YOLO models, automatically switching to specialized model types
     (YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
@@ -41,7 +40,7 @@ class YOLO(Model):
         task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
 
     Examples:
-        Load a pretrained YOLOv11n detection model
+        Load a pretrained YOLO11n detection model
         >>> model = YOLO("yolo11n.pt")
 
         Load a pretrained YOLO11n segmentation model
@@ -52,22 +51,16 @@ class YOLO(Model):
     """
 
     def __init__(self, model: str | Path = "yolo11n.pt", task: str | None = None, verbose: bool = False):
-        """
-        Initialize a YOLO model.
+        """Initialize a YOLO model.
 
-        This constructor initializes a YOLO model, automatically switching to specialized model types
-        (YOLOWorld or YOLOE) based on the model filename.
+        This constructor initializes a YOLO model, automatically switching to specialized model types (YOLOWorld or
+        YOLOE) based on the model filename.
 
         Args:
             model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
-            task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
-                Defaults to auto-detection based on model.
+            task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'. Defaults
+                to auto-detection based on model.
             verbose (bool): Display model info on load.
-
-        Examples:
-            >>> from ultralytics import YOLO
-            >>> model = YOLO("yolo11n.pt")  # load a pretrained YOLOv11n detection model
-            >>> model = YOLO("yolo11n-seg.pt")  # load a pretrained YOLO11n segmentation model
         """
         path = Path(model if isinstance(model, (str, Path)) else "")
         if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
@@ -126,12 +119,11 @@ class YOLO(Model):
 
 
 class YOLOWorld(Model):
-    """
-    YOLO-World object detection model.
+    """YOLO-World object detection model.
 
-    YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions
-    without requiring training on specific classes. It extends the YOLO architecture to support real-time
-    open-vocabulary detection.
+    YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions without
+    requiring training on specific classes. It extends the YOLO architecture to support real-time open-vocabulary
+    detection.
 
     Attributes:
         model: The loaded YOLO-World model instance.
@@ -152,11 +144,10 @@ class YOLOWorld(Model):
     """
 
     def __init__(self, model: str | Path = "yolov8s-world.pt", verbose: bool = False) -> None:
-        """
-        Initialize YOLOv8-World model with a pre-trained model file.
+        """Initialize YOLOv8-World model with a pre-trained model file.
 
-        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
-        COCO class names.
+        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default COCO
+        class names.
 
         Args:
             model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
@@ -181,8 +172,7 @@ class YOLOWorld(Model):
         }
 
     def set_classes(self, classes: list[str]) -> None:
-        """
-        Set the model's class names for detection.
+        """Set the model's class names for detection.
 
         Args:
             classes (list[str]): A list of categories i.e. ["person"].
@@ -200,11 +190,10 @@ class YOLOWorld(Model):
 
 
 class YOLOE(Model):
-    """
-    YOLOE object detection and segmentation model.
+    """YOLOE object detection and segmentation model.
 
-    YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with
-    improved performance and additional features like visual and text positional embeddings.
+    YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with improved
+    performance and additional features like visual and text positional embeddings.
 
     Attributes:
         model: The loaded YOLOE model instance.
@@ -235,8 +224,7 @@ class YOLOE(Model):
     """
 
     def __init__(self, model: str | Path = "yoloe-11s-seg.pt", task: str | None = None, verbose: bool = False) -> None:
-        """
-        Initialize YOLOE model with a pre-trained model file.
+        """Initialize YOLOE model with a pre-trained model file.
 
         Args:
             model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
@@ -269,11 +257,10 @@ class YOLOE(Model):
         return self.model.get_text_pe(texts)
 
     def get_visual_pe(self, img, visual):
-        """
-        Get visual positional embeddings for the given image and visual features.
+        """Get visual positional embeddings for the given image and visual features.
 
-        This method extracts positional embeddings from visual features based on the input image. It requires
-        that the model is an instance of YOLOEModel.
+        This method extracts positional embeddings from visual features based on the input image. It requires that the
+        model is an instance of YOLOEModel.
 
         Args:
             img (torch.Tensor): Input image tensor.
@@ -292,11 +279,10 @@ class YOLOE(Model):
         return self.model.get_visual_pe(img, visual)
 
     def set_vocab(self, vocab: list[str], names: list[str]) -> None:
-        """
-        Set vocabulary and class names for the YOLOE model.
+        """Set vocabulary and class names for the YOLOE model.
 
-        This method configures the vocabulary and class names used by the model for text processing and
-        classification tasks. The model must be an instance of YOLOEModel.
+        This method configures the vocabulary and class names used by the model for text processing and classification
+        tasks. The model must be an instance of YOLOEModel.
 
         Args:
             vocab (list[str]): Vocabulary list containing tokens or words used by the model for text processing.
@@ -318,8 +304,7 @@ class YOLOE(Model):
         return self.model.get_vocab(names)
 
     def set_classes(self, classes: list[str], embeddings: torch.Tensor | None = None) -> None:
-        """
-        Set the model's class names and embeddings for detection.
+        """Set the model's class names and embeddings for detection.
 
         Args:
             classes (list[str]): A list of categories i.e. ["person"].
@@ -344,8 +329,7 @@ class YOLOE(Model):
         refer_data: str | None = None,
         **kwargs,
     ):
-        """
-        Validate the model using text or visual prompts.
+        """Validate the model using text or visual prompts.
 
         Args:
             validator (callable, optional): A callable validator function. If None, a default validator is loaded.
@@ -373,19 +357,18 @@ class YOLOE(Model):
         predictor=yolo.yoloe.YOLOEVPDetectPredictor,
         **kwargs,
     ):
-        """
-        Run prediction on images, videos, directories, streams, etc.
+        """Run prediction on images, videos, directories, streams, etc.
 
         Args:
-            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
-                directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
-            stream (bool): Whether to stream the prediction results. If True, results are yielded as a
-                generator as they are computed.
-            visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include
-                'bboxes' and 'cls' keys when non-empty.
+            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths, directory
+                paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
+            stream (bool): Whether to stream the prediction results. If True, results are yielded as a generator as they
+                are computed.
+            visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include 'bboxes'
+                and 'cls' keys when non-empty.
             refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
-            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
-                loaded based on the task.
+            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically loaded
+                based on the task.
             **kwargs (Any): Additional keyword arguments passed to the predictor.
 
         Returns:
@@ -416,6 +399,7 @@ class YOLOE(Model):
                 "batch": 1,
                 "device": kwargs.get("device", None),
                 "half": kwargs.get("half", False),
+                "imgsz": kwargs.get("imgsz", self.overrides["imgsz"]),
             },
             _callbacks=self.callbacks,
         )
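
The functional change in this file is the added "imgsz" entry when YOLOE builds its internal visual-prompt predictor, so a user-passed image size is no longer dropped. A hedged usage sketch of that visual-prompt path follows; the weights name mirrors the class default, and the source path and box coordinates are placeholders, not values from this diff.

import numpy as np

from ultralytics import YOLOE

model = YOLOE("yoloe-11s-seg.pt")  # prompt-capable YOLOE weights (class default)
results = model.predict(
    "image.jpg",  # placeholder source
    visual_prompts={
        "bboxes": np.array([[100.0, 150.0, 300.0, 400.0]]),  # example prompt box in xyxy
        "cls": np.array([0]),  # class id assigned to the prompt box
    },
    imgsz=640,  # now forwarded to the internal predictor by the change above
)
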
ultralytics/models/yolo/obb/predict.py
CHANGED
@@ -8,8 +8,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
 
 
 class OBBPredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
+    """A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
 
     This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
     bounding boxes.
@@ -27,30 +26,22 @@ class OBBPredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize OBBPredictor with optional model and data configuration overrides.
+        """Initialize OBBPredictor with optional model and data configuration overrides.
 
         Args:
             cfg (dict, optional): Default configuration for the predictor.
             overrides (dict, optional): Configuration overrides that take precedence over the default config.
             _callbacks (list, optional): List of callback functions to be invoked during prediction.
-
-        Examples:
-            >>> from ultralytics.utils import ASSETS
-            >>> from ultralytics.models.yolo.obb import OBBPredictor
-            >>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
-            >>> predictor = OBBPredictor(overrides=args)
         """
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "obb"
 
     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct the result object from the prediction.
+        """Construct the result object from the prediction.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
-                the last dimension contains [x, y, w, h, confidence, class_id, angle].
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where the
+                last dimension contains [x, y, w, h, confidence, class_id, angle].
             img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
ultralytics/models/yolo/obb/train.py
CHANGED
@@ -12,15 +12,14 @@ from ultralytics.utils import DEFAULT_CFG, RANK
 
 
 class OBBTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
+    """A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
 
-    This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for
-    detecting objects at arbitrary angles rather than just axis-aligned rectangles.
+    This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for detecting
+    objects at arbitrary angles rather than just axis-aligned rectangles.
 
     Attributes:
-        loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
-            and dfl_loss.
+        loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss, and
+            dfl_loss.
 
     Methods:
         get_model: Return OBBModel initialized with specified config and weights.
@@ -34,14 +33,13 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
-        """
-        Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
+        """Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
 
         Args:
-            cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
-                model configuration.
-            overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
-                will take precedence over those in cfg.
+            cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and model
+                configuration.
+            overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here will
+                take precedence over those in cfg.
             _callbacks (list[Any], optional): List of callback functions to be invoked during training.
         """
         if overrides is None:
@@ -52,8 +50,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     def get_model(
         self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
     ) -> OBBModel:
-        """
-        Return OBBModel initialized with specified config and weights.
+        """Return OBBModel initialized with specified config and weights.
 
         Args:
             cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary