dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/model.py
CHANGED
|
@@ -24,8 +24,7 @@ from ultralytics.utils import ROOT, YAML
|
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class YOLO(Model):
|
|
27
|
-
"""
|
|
28
|
-
YOLO (You Only Look Once) object detection model.
|
|
27
|
+
"""YOLO (You Only Look Once) object detection model.
|
|
29
28
|
|
|
30
29
|
This class provides a unified interface for YOLO models, automatically switching to specialized model types
|
|
31
30
|
(YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
|
|
@@ -41,33 +40,27 @@ class YOLO(Model):
|
|
|
41
40
|
task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
|
|
42
41
|
|
|
43
42
|
Examples:
|
|
44
|
-
Load a pretrained
|
|
45
|
-
>>> model = YOLO("
|
|
43
|
+
Load a pretrained YOLO26n detection model
|
|
44
|
+
>>> model = YOLO("yolo26n.pt")
|
|
46
45
|
|
|
47
|
-
Load a pretrained
|
|
48
|
-
>>> model = YOLO("
|
|
46
|
+
Load a pretrained YOLO26n segmentation model
|
|
47
|
+
>>> model = YOLO("yolo26n-seg.pt")
|
|
49
48
|
|
|
50
49
|
Initialize from a YAML configuration
|
|
51
|
-
>>> model = YOLO("
|
|
50
|
+
>>> model = YOLO("yolo26n.yaml")
|
|
52
51
|
"""
|
|
53
52
|
|
|
54
|
-
def __init__(self, model: str | Path = "
|
|
55
|
-
"""
|
|
56
|
-
Initialize a YOLO model.
|
|
53
|
+
def __init__(self, model: str | Path = "yolo26n.pt", task: str | None = None, verbose: bool = False):
|
|
54
|
+
"""Initialize a YOLO model.
|
|
57
55
|
|
|
58
|
-
This constructor initializes a YOLO model, automatically switching to specialized model types
|
|
59
|
-
|
|
56
|
+
This constructor initializes a YOLO model, automatically switching to specialized model types (YOLOWorld or
|
|
57
|
+
YOLOE) based on the model filename.
|
|
60
58
|
|
|
61
59
|
Args:
|
|
62
|
-
model (str | Path): Model name or path to model file, i.e. '
|
|
63
|
-
task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
|
|
64
|
-
|
|
60
|
+
model (str | Path): Model name or path to model file, i.e. 'yolo26n.pt', 'yolo26n.yaml'.
|
|
61
|
+
task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'. Defaults
|
|
62
|
+
to auto-detection based on model.
|
|
65
63
|
verbose (bool): Display model info on load.
|
|
66
|
-
|
|
67
|
-
Examples:
|
|
68
|
-
>>> from ultralytics import YOLO
|
|
69
|
-
>>> model = YOLO("yolo11n.pt") # load a pretrained YOLOv11n detection model
|
|
70
|
-
>>> model = YOLO("yolo11n-seg.pt") # load a pretrained YOLO11n segmentation model
|
|
71
64
|
"""
|
|
72
65
|
path = Path(model if isinstance(model, (str, Path)) else "")
|
|
73
66
|
if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
|
|
@@ -126,12 +119,11 @@ class YOLO(Model):
|
|
|
126
119
|
|
|
127
120
|
|
|
128
121
|
class YOLOWorld(Model):
|
|
129
|
-
"""
|
|
130
|
-
YOLO-World object detection model.
|
|
122
|
+
"""YOLO-World object detection model.
|
|
131
123
|
|
|
132
|
-
YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions
|
|
133
|
-
|
|
134
|
-
|
|
124
|
+
YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions without
|
|
125
|
+
requiring training on specific classes. It extends the YOLO architecture to support real-time open-vocabulary
|
|
126
|
+
detection.
|
|
135
127
|
|
|
136
128
|
Attributes:
|
|
137
129
|
model: The loaded YOLO-World model instance.
|
|
@@ -152,11 +144,10 @@ class YOLOWorld(Model):
|
|
|
152
144
|
"""
|
|
153
145
|
|
|
154
146
|
def __init__(self, model: str | Path = "yolov8s-world.pt", verbose: bool = False) -> None:
|
|
155
|
-
"""
|
|
156
|
-
Initialize YOLOv8-World model with a pre-trained model file.
|
|
147
|
+
"""Initialize YOLOv8-World model with a pre-trained model file.
|
|
157
148
|
|
|
158
|
-
Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
|
|
159
|
-
|
|
149
|
+
Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default COCO
|
|
150
|
+
class names.
|
|
160
151
|
|
|
161
152
|
Args:
|
|
162
153
|
model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
|
|
@@ -181,8 +172,7 @@ class YOLOWorld(Model):
|
|
|
181
172
|
}
|
|
182
173
|
|
|
183
174
|
def set_classes(self, classes: list[str]) -> None:
|
|
184
|
-
"""
|
|
185
|
-
Set the model's class names for detection.
|
|
175
|
+
"""Set the model's class names for detection.
|
|
186
176
|
|
|
187
177
|
Args:
|
|
188
178
|
classes (list[str]): A list of categories i.e. ["person"].
|
|
@@ -200,11 +190,10 @@ class YOLOWorld(Model):
|
|
|
200
190
|
|
|
201
191
|
|
|
202
192
|
class YOLOE(Model):
|
|
203
|
-
"""
|
|
204
|
-
YOLOE object detection and segmentation model.
|
|
193
|
+
"""YOLOE object detection and segmentation model.
|
|
205
194
|
|
|
206
|
-
YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with
|
|
207
|
-
|
|
195
|
+
YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with improved
|
|
196
|
+
performance and additional features like visual and text positional embeddings.
|
|
208
197
|
|
|
209
198
|
Attributes:
|
|
210
199
|
model: The loaded YOLOE model instance.
|
|
@@ -235,8 +224,7 @@ class YOLOE(Model):
|
|
|
235
224
|
"""
|
|
236
225
|
|
|
237
226
|
def __init__(self, model: str | Path = "yoloe-11s-seg.pt", task: str | None = None, verbose: bool = False) -> None:
|
|
238
|
-
"""
|
|
239
|
-
Initialize YOLOE model with a pre-trained model file.
|
|
227
|
+
"""Initialize YOLOE model with a pre-trained model file.
|
|
240
228
|
|
|
241
229
|
Args:
|
|
242
230
|
model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
|
|
@@ -269,11 +257,10 @@ class YOLOE(Model):
|
|
|
269
257
|
return self.model.get_text_pe(texts)
|
|
270
258
|
|
|
271
259
|
def get_visual_pe(self, img, visual):
|
|
272
|
-
"""
|
|
273
|
-
Get visual positional embeddings for the given image and visual features.
|
|
260
|
+
"""Get visual positional embeddings for the given image and visual features.
|
|
274
261
|
|
|
275
|
-
This method extracts positional embeddings from visual features based on the input image. It requires
|
|
276
|
-
|
|
262
|
+
This method extracts positional embeddings from visual features based on the input image. It requires that the
|
|
263
|
+
model is an instance of YOLOEModel.
|
|
277
264
|
|
|
278
265
|
Args:
|
|
279
266
|
img (torch.Tensor): Input image tensor.
|
|
@@ -292,11 +279,10 @@ class YOLOE(Model):
|
|
|
292
279
|
return self.model.get_visual_pe(img, visual)
|
|
293
280
|
|
|
294
281
|
def set_vocab(self, vocab: list[str], names: list[str]) -> None:
|
|
295
|
-
"""
|
|
296
|
-
Set vocabulary and class names for the YOLOE model.
|
|
282
|
+
"""Set vocabulary and class names for the YOLOE model.
|
|
297
283
|
|
|
298
|
-
This method configures the vocabulary and class names used by the model for text processing and
|
|
299
|
-
|
|
284
|
+
This method configures the vocabulary and class names used by the model for text processing and classification
|
|
285
|
+
tasks. The model must be an instance of YOLOEModel.
|
|
300
286
|
|
|
301
287
|
Args:
|
|
302
288
|
vocab (list[str]): Vocabulary list containing tokens or words used by the model for text processing.
|
|
@@ -318,8 +304,7 @@ class YOLOE(Model):
|
|
|
318
304
|
return self.model.get_vocab(names)
|
|
319
305
|
|
|
320
306
|
def set_classes(self, classes: list[str], embeddings: torch.Tensor | None = None) -> None:
|
|
321
|
-
"""
|
|
322
|
-
Set the model's class names and embeddings for detection.
|
|
307
|
+
"""Set the model's class names and embeddings for detection.
|
|
323
308
|
|
|
324
309
|
Args:
|
|
325
310
|
classes (list[str]): A list of categories i.e. ["person"].
|
|
@@ -344,8 +329,7 @@ class YOLOE(Model):
|
|
|
344
329
|
refer_data: str | None = None,
|
|
345
330
|
**kwargs,
|
|
346
331
|
):
|
|
347
|
-
"""
|
|
348
|
-
Validate the model using text or visual prompts.
|
|
332
|
+
"""Validate the model using text or visual prompts.
|
|
349
333
|
|
|
350
334
|
Args:
|
|
351
335
|
validator (callable, optional): A callable validator function. If None, a default validator is loaded.
|
|
@@ -373,19 +357,18 @@ class YOLOE(Model):
|
|
|
373
357
|
predictor=yolo.yoloe.YOLOEVPDetectPredictor,
|
|
374
358
|
**kwargs,
|
|
375
359
|
):
|
|
376
|
-
"""
|
|
377
|
-
Run prediction on images, videos, directories, streams, etc.
|
|
360
|
+
"""Run prediction on images, videos, directories, streams, etc.
|
|
378
361
|
|
|
379
362
|
Args:
|
|
380
|
-
source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
|
|
381
|
-
|
|
382
|
-
stream (bool): Whether to stream the prediction results. If True, results are yielded as a
|
|
383
|
-
|
|
384
|
-
visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include
|
|
385
|
-
|
|
363
|
+
source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths, directory
|
|
364
|
+
paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
|
|
365
|
+
stream (bool): Whether to stream the prediction results. If True, results are yielded as a generator as they
|
|
366
|
+
are computed.
|
|
367
|
+
visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include 'bboxes'
|
|
368
|
+
and 'cls' keys when non-empty.
|
|
386
369
|
refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
|
|
387
|
-
predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
|
|
388
|
-
|
|
370
|
+
predictor (callable, optional): Custom predictor function. If None, a predictor is automatically loaded
|
|
371
|
+
based on the task.
|
|
389
372
|
**kwargs (Any): Additional keyword arguments passed to the predictor.
|
|
390
373
|
|
|
391
374
|
Returns:
|
|
@@ -416,7 +399,7 @@ class YOLOE(Model):
|
|
|
416
399
|
"batch": 1,
|
|
417
400
|
"device": kwargs.get("device", None),
|
|
418
401
|
"half": kwargs.get("half", False),
|
|
419
|
-
"imgsz": kwargs.get("imgsz", self.overrides
|
|
402
|
+
"imgsz": kwargs.get("imgsz", self.overrides.get("imgsz", 640)),
|
|
420
403
|
},
|
|
421
404
|
_callbacks=self.callbacks,
|
|
422
405
|
)
|
|
@@ -8,8 +8,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
|
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class OBBPredictor(DetectionPredictor):
|
|
11
|
-
"""
|
|
12
|
-
A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
|
|
11
|
+
"""A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
|
|
13
12
|
|
|
14
13
|
This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
|
|
15
14
|
bounding boxes.
|
|
@@ -21,36 +20,28 @@ class OBBPredictor(DetectionPredictor):
|
|
|
21
20
|
Examples:
|
|
22
21
|
>>> from ultralytics.utils import ASSETS
|
|
23
22
|
>>> from ultralytics.models.yolo.obb import OBBPredictor
|
|
24
|
-
>>> args = dict(model="
|
|
23
|
+
>>> args = dict(model="yolo26n-obb.pt", source=ASSETS)
|
|
25
24
|
>>> predictor = OBBPredictor(overrides=args)
|
|
26
25
|
>>> predictor.predict_cli()
|
|
27
26
|
"""
|
|
28
27
|
|
|
29
28
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
30
|
-
"""
|
|
31
|
-
Initialize OBBPredictor with optional model and data configuration overrides.
|
|
29
|
+
"""Initialize OBBPredictor with optional model and data configuration overrides.
|
|
32
30
|
|
|
33
31
|
Args:
|
|
34
32
|
cfg (dict, optional): Default configuration for the predictor.
|
|
35
33
|
overrides (dict, optional): Configuration overrides that take precedence over the default config.
|
|
36
34
|
_callbacks (list, optional): List of callback functions to be invoked during prediction.
|
|
37
|
-
|
|
38
|
-
Examples:
|
|
39
|
-
>>> from ultralytics.utils import ASSETS
|
|
40
|
-
>>> from ultralytics.models.yolo.obb import OBBPredictor
|
|
41
|
-
>>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
|
|
42
|
-
>>> predictor = OBBPredictor(overrides=args)
|
|
43
35
|
"""
|
|
44
36
|
super().__init__(cfg, overrides, _callbacks)
|
|
45
37
|
self.args.task = "obb"
|
|
46
38
|
|
|
47
39
|
def construct_result(self, pred, img, orig_img, img_path):
|
|
48
|
-
"""
|
|
49
|
-
Construct the result object from the prediction.
|
|
40
|
+
"""Construct the result object from the prediction.
|
|
50
41
|
|
|
51
42
|
Args:
|
|
52
|
-
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
|
|
53
|
-
|
|
43
|
+
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where the
|
|
44
|
+
last dimension contains [x, y, w, h, confidence, class_id, angle].
|
|
54
45
|
img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
|
|
55
46
|
orig_img (np.ndarray): The original image before preprocessing.
|
|
56
47
|
img_path (str): The path to the original image.
|
|
@@ -59,7 +50,7 @@ class OBBPredictor(DetectionPredictor):
|
|
|
59
50
|
(Results): The result object containing the original image, image path, class names, and oriented bounding
|
|
60
51
|
boxes.
|
|
61
52
|
"""
|
|
62
|
-
rboxes =
|
|
53
|
+
rboxes = torch.cat([pred[:, :4], pred[:, -1:]], dim=-1)
|
|
63
54
|
rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
|
|
64
55
|
obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
|
|
65
56
|
return Results(orig_img, path=img_path, names=self.model.names, obb=obb)
|
|
@@ -12,15 +12,14 @@ from ultralytics.utils import DEFAULT_CFG, RANK
|
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class OBBTrainer(yolo.detect.DetectionTrainer):
|
|
15
|
-
"""
|
|
16
|
-
A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
|
|
15
|
+
"""A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
|
|
17
16
|
|
|
18
|
-
This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for
|
|
19
|
-
|
|
17
|
+
This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for detecting
|
|
18
|
+
objects at arbitrary angles rather than just axis-aligned rectangles.
|
|
20
19
|
|
|
21
20
|
Attributes:
|
|
22
|
-
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
|
|
23
|
-
|
|
21
|
+
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss, and
|
|
22
|
+
dfl_loss.
|
|
24
23
|
|
|
25
24
|
Methods:
|
|
26
25
|
get_model: Return OBBModel initialized with specified config and weights.
|
|
@@ -28,20 +27,19 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
|
|
|
28
27
|
|
|
29
28
|
Examples:
|
|
30
29
|
>>> from ultralytics.models.yolo.obb import OBBTrainer
|
|
31
|
-
>>> args = dict(model="
|
|
30
|
+
>>> args = dict(model="yolo26n-obb.pt", data="dota8.yaml", epochs=3)
|
|
32
31
|
>>> trainer = OBBTrainer(overrides=args)
|
|
33
32
|
>>> trainer.train()
|
|
34
33
|
"""
|
|
35
34
|
|
|
36
35
|
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
|
|
37
|
-
"""
|
|
38
|
-
Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
|
|
36
|
+
"""Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
|
|
39
37
|
|
|
40
38
|
Args:
|
|
41
|
-
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
|
|
42
|
-
|
|
43
|
-
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
|
|
44
|
-
|
|
39
|
+
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and model
|
|
40
|
+
configuration.
|
|
41
|
+
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here will
|
|
42
|
+
take precedence over those in cfg.
|
|
45
43
|
_callbacks (list[Any], optional): List of callback functions to be invoked during training.
|
|
46
44
|
"""
|
|
47
45
|
if overrides is None:
|
|
@@ -52,8 +50,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
|
|
|
52
50
|
def get_model(
|
|
53
51
|
self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
|
|
54
52
|
) -> OBBModel:
|
|
55
|
-
"""
|
|
56
|
-
Return OBBModel initialized with specified config and weights.
|
|
53
|
+
"""Return OBBModel initialized with specified config and weights.
|
|
57
54
|
|
|
58
55
|
Args:
|
|
59
56
|
cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
|
|
@@ -66,7 +63,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
|
|
|
66
63
|
|
|
67
64
|
Examples:
|
|
68
65
|
>>> trainer = OBBTrainer()
|
|
69
|
-
>>> model = trainer.get_model(cfg="
|
|
66
|
+
>>> model = trainer.get_model(cfg="yolo26n-obb.yaml", weights="yolo26n-obb.pt")
|
|
70
67
|
"""
|
|
71
68
|
model = OBBModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
|
|
72
69
|
if weights:
|
|
@@ -76,7 +73,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
|
|
|
76
73
|
|
|
77
74
|
def get_validator(self):
|
|
78
75
|
"""Return an instance of OBBValidator for validation of YOLO model."""
|
|
79
|
-
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
|
|
76
|
+
self.loss_names = "box_loss", "cls_loss", "dfl_loss", "angle_loss"
|
|
80
77
|
return yolo.obb.OBBValidator(
|
|
81
78
|
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
|
82
79
|
)
|
|
@@ -12,11 +12,11 @@ from ultralytics.models.yolo.detect import DetectionValidator
|
|
|
12
12
|
from ultralytics.utils import LOGGER, ops
|
|
13
13
|
from ultralytics.utils.metrics import OBBMetrics, batch_probiou
|
|
14
14
|
from ultralytics.utils.nms import TorchNMS
|
|
15
|
+
from ultralytics.utils.plotting import plot_images
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
class OBBValidator(DetectionValidator):
|
|
18
|
-
"""
|
|
19
|
-
A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
|
|
19
|
+
"""A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
|
|
20
20
|
|
|
21
21
|
This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
|
|
22
22
|
satellite imagery where objects can appear at various orientations.
|
|
@@ -38,20 +38,19 @@ class OBBValidator(DetectionValidator):
|
|
|
38
38
|
|
|
39
39
|
Examples:
|
|
40
40
|
>>> from ultralytics.models.yolo.obb import OBBValidator
|
|
41
|
-
>>> args = dict(model="
|
|
41
|
+
>>> args = dict(model="yolo26n-obb.pt", data="dota8.yaml")
|
|
42
42
|
>>> validator = OBBValidator(args=args)
|
|
43
43
|
>>> validator(model=args["model"])
|
|
44
44
|
"""
|
|
45
45
|
|
|
46
46
|
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
|
47
|
-
"""
|
|
48
|
-
Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
|
|
47
|
+
"""Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
|
|
49
48
|
|
|
50
|
-
This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
|
|
51
|
-
|
|
49
|
+
This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models. It
|
|
50
|
+
extends the DetectionValidator class and configures it specifically for the OBB task.
|
|
52
51
|
|
|
53
52
|
Args:
|
|
54
|
-
dataloader (torch.utils.data.DataLoader, optional):
|
|
53
|
+
dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
|
|
55
54
|
save_dir (str | Path, optional): Directory to save results.
|
|
56
55
|
args (dict | SimpleNamespace, optional): Arguments containing validation parameters.
|
|
57
56
|
_callbacks (list, optional): List of callback functions to be called during validation.
|
|
@@ -61,8 +60,7 @@ class OBBValidator(DetectionValidator):
|
|
|
61
60
|
self.metrics = OBBMetrics()
|
|
62
61
|
|
|
63
62
|
def init_metrics(self, model: torch.nn.Module) -> None:
|
|
64
|
-
"""
|
|
65
|
-
Initialize evaluation metrics for YOLO obb validation.
|
|
63
|
+
"""Initialize evaluation metrics for YOLO obb validation.
|
|
66
64
|
|
|
67
65
|
Args:
|
|
68
66
|
model (torch.nn.Module): Model to validate.
|
|
@@ -73,19 +71,18 @@ class OBBValidator(DetectionValidator):
|
|
|
73
71
|
self.confusion_matrix.task = "obb" # set confusion matrix task to 'obb'
|
|
74
72
|
|
|
75
73
|
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
|
|
76
|
-
"""
|
|
77
|
-
Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
|
|
74
|
+
"""Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
|
|
78
75
|
|
|
79
76
|
Args:
|
|
80
77
|
preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
|
|
81
78
|
class labels and bounding boxes.
|
|
82
|
-
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
|
|
83
|
-
|
|
79
|
+
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth class
|
|
80
|
+
labels and bounding boxes.
|
|
84
81
|
|
|
85
82
|
Returns:
|
|
86
|
-
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
|
|
87
|
-
|
|
88
|
-
|
|
83
|
+
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy array
|
|
84
|
+
with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy of
|
|
85
|
+
predictions compared to the ground truth.
|
|
89
86
|
|
|
90
87
|
Examples:
|
|
91
88
|
>>> detections = torch.rand(100, 7) # 100 sample detections
|
|
@@ -99,7 +96,8 @@ class OBBValidator(DetectionValidator):
|
|
|
99
96
|
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
|
|
100
97
|
|
|
101
98
|
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
|
|
102
|
-
"""
|
|
99
|
+
"""Postprocess OBB predictions.
|
|
100
|
+
|
|
103
101
|
Args:
|
|
104
102
|
preds (torch.Tensor): Raw predictions from the model.
|
|
105
103
|
|
|
@@ -112,8 +110,7 @@ class OBBValidator(DetectionValidator):
|
|
|
112
110
|
return preds
|
|
113
111
|
|
|
114
112
|
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
|
115
|
-
"""
|
|
116
|
-
Prepare batch data for OBB validation with proper scaling and formatting.
|
|
113
|
+
"""Prepare batch data for OBB validation with proper scaling and formatting.
|
|
117
114
|
|
|
118
115
|
Args:
|
|
119
116
|
si (int): Batch index to process.
|
|
@@ -145,33 +142,41 @@ class OBBValidator(DetectionValidator):
|
|
|
145
142
|
"im_file": batch["im_file"][si],
|
|
146
143
|
}
|
|
147
144
|
|
|
148
|
-
def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
|
|
149
|
-
"""
|
|
150
|
-
Plot predicted bounding boxes on input images and save the result.
|
|
145
|
+
def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
|
|
146
|
+
"""Plot predicted bounding boxes on input images and save the result.
|
|
151
147
|
|
|
152
148
|
Args:
|
|
153
149
|
batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
|
|
154
|
-
preds (list[torch.Tensor]): List of prediction
|
|
150
|
+
preds (list[dict[str, torch.Tensor]]): List of prediction dictionaries for each image in the batch.
|
|
155
151
|
ni (int): Batch index used for naming the output file.
|
|
156
152
|
|
|
157
153
|
Examples:
|
|
158
154
|
>>> validator = OBBValidator()
|
|
159
155
|
>>> batch = {"img": images, "im_file": paths}
|
|
160
|
-
>>> preds = [torch.rand(10,
|
|
156
|
+
>>> preds = [{"bboxes": torch.rand(10, 5), "cls": torch.zeros(10), "conf": torch.rand(10)}]
|
|
161
157
|
>>> validator.plot_predictions(batch, preds, 0)
|
|
162
158
|
"""
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
159
|
+
if not preds:
|
|
160
|
+
return
|
|
161
|
+
for i, pred in enumerate(preds):
|
|
162
|
+
pred["batch_idx"] = torch.ones_like(pred["conf"]) * i
|
|
163
|
+
keys = preds[0].keys()
|
|
164
|
+
batched_preds = {k: torch.cat([x[k] for x in preds], dim=0) for k in keys}
|
|
165
|
+
plot_images(
|
|
166
|
+
images=batch["img"],
|
|
167
|
+
labels=batched_preds,
|
|
168
|
+
paths=batch["im_file"],
|
|
169
|
+
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
|
|
170
|
+
names=self.names,
|
|
171
|
+
on_plot=self.on_plot,
|
|
172
|
+
)
|
|
167
173
|
|
|
168
174
|
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
|
169
|
-
"""
|
|
170
|
-
Convert YOLO predictions to COCO JSON format with rotated bounding box information.
|
|
175
|
+
"""Convert YOLO predictions to COCO JSON format with rotated bounding box information.
|
|
171
176
|
|
|
172
177
|
Args:
|
|
173
|
-
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
|
|
174
|
-
|
|
178
|
+
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys with
|
|
179
|
+
bounding box coordinates, confidence scores, and class predictions.
|
|
175
180
|
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
|
176
181
|
|
|
177
182
|
Notes:
|
|
@@ -197,8 +202,7 @@ class OBBValidator(DetectionValidator):
|
|
|
197
202
|
)
|
|
198
203
|
|
|
199
204
|
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
|
200
|
-
"""
|
|
201
|
-
Save YOLO OBB detections to a text file in normalized coordinates.
|
|
205
|
+
"""Save YOLO OBB detections to a text file in normalized coordinates.
|
|
202
206
|
|
|
203
207
|
Args:
|
|
204
208
|
predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
|
|
@@ -233,8 +237,7 @@ class OBBValidator(DetectionValidator):
|
|
|
233
237
|
}
|
|
234
238
|
|
|
235
239
|
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
|
|
236
|
-
"""
|
|
237
|
-
Evaluate YOLO output in JSON format and save predictions in DOTA format.
|
|
240
|
+
"""Evaluate YOLO output in JSON format and save predictions in DOTA format.
|
|
238
241
|
|
|
239
242
|
Args:
|
|
240
243
|
stats (dict[str, Any]): Performance statistics dictionary.
|
|
@@ -1,12 +1,11 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
3
|
from ultralytics.models.yolo.detect.predict import DetectionPredictor
|
|
4
|
-
from ultralytics.utils import DEFAULT_CFG,
|
|
4
|
+
from ultralytics.utils import DEFAULT_CFG, ops
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
class PosePredictor(DetectionPredictor):
|
|
8
|
-
"""
|
|
9
|
-
A class extending the DetectionPredictor class for prediction based on a pose model.
|
|
8
|
+
"""A class extending the DetectionPredictor class for prediction based on a pose model.
|
|
10
9
|
|
|
11
10
|
This class specializes in pose estimation, handling keypoints detection alongside standard object detection
|
|
12
11
|
capabilities inherited from DetectionPredictor.
|
|
@@ -21,41 +20,27 @@ class PosePredictor(DetectionPredictor):
|
|
|
21
20
|
Examples:
|
|
22
21
|
>>> from ultralytics.utils import ASSETS
|
|
23
22
|
>>> from ultralytics.models.yolo.pose import PosePredictor
|
|
24
|
-
>>> args = dict(model="
|
|
23
|
+
>>> args = dict(model="yolo26n-pose.pt", source=ASSETS)
|
|
25
24
|
>>> predictor = PosePredictor(overrides=args)
|
|
26
25
|
>>> predictor.predict_cli()
|
|
27
26
|
"""
|
|
28
27
|
|
|
29
28
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
30
|
-
"""
|
|
31
|
-
Initialize PosePredictor for pose estimation tasks.
|
|
29
|
+
"""Initialize PosePredictor for pose estimation tasks.
|
|
32
30
|
|
|
33
|
-
Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific
|
|
34
|
-
|
|
31
|
+
Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific warnings
|
|
32
|
+
for Apple MPS.
|
|
35
33
|
|
|
36
34
|
Args:
|
|
37
35
|
cfg (Any): Configuration for the predictor.
|
|
38
36
|
overrides (dict, optional): Configuration overrides that take precedence over cfg.
|
|
39
37
|
_callbacks (list, optional): List of callback functions to be invoked during prediction.
|
|
40
|
-
|
|
41
|
-
Examples:
|
|
42
|
-
>>> from ultralytics.utils import ASSETS
|
|
43
|
-
>>> from ultralytics.models.yolo.pose import PosePredictor
|
|
44
|
-
>>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
|
|
45
|
-
>>> predictor = PosePredictor(overrides=args)
|
|
46
|
-
>>> predictor.predict_cli()
|
|
47
38
|
"""
|
|
48
39
|
super().__init__(cfg, overrides, _callbacks)
|
|
49
40
|
self.args.task = "pose"
|
|
50
|
-
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
|
51
|
-
LOGGER.warning(
|
|
52
|
-
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
|
53
|
-
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
|
54
|
-
)
|
|
55
41
|
|
|
56
42
|
def construct_result(self, pred, img, orig_img, img_path):
|
|
57
|
-
"""
|
|
58
|
-
Construct the result object from the prediction, including keypoints.
|
|
43
|
+
"""Construct the result object from the prediction, including keypoints.
|
|
59
44
|
|
|
60
45
|
Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
|
|
61
46
|
result object.
|