ultralytics 8.3.143__py3-none-any.whl → 8.3.144__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -157
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -7
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +184 -75
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
- ultralytics-8.3.144.dist-info/RECORD +272 -0
- ultralytics-8.3.143.dist-info/RECORD +0 -272
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
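Per the file list above, ultralytics/__init__.py changes by a single line, which is the version bump. A minimal sketch (assuming the wheel is installed from PyPI) to confirm the upgrade took effect:

import ultralytics

# The package exposes its version string; it should read "8.3.144" after upgrading.
print(ultralytics.__version__)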
ultralytics/models/yolo/model.py
CHANGED
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
 
 from ultralytics.data.build import load_inference_source
 from ultralytics.engine.model import Model
@@ -19,9 +20,34 @@ from ultralytics.utils import ROOT, YAML
 
 
 class YOLO(Model):
-    """
+    """
+    YOLO (You Only Look Once) object detection model.
 
-
+    This class provides a unified interface for YOLO models, automatically switching to specialized model types
+    (YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
+    detection, segmentation, classification, pose estimation, and oriented bounding box detection.
+
+    Attributes:
+        model: The loaded YOLO model instance.
+        task: The task type (detect, segment, classify, pose, obb).
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize a YOLO model with automatic type detection.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+
+    Examples:
+        Load a pretrained YOLOv11n detection model
+        >>> model = YOLO("yolo11n.pt")
+
+        Load a pretrained YOLO11n segmentation model
+        >>> model = YOLO("yolo11n-seg.pt")
+
+        Initialize from a YAML configuration
+        >>> model = YOLO("yolo11n.yaml")
+    """
+
+    def __init__(self, model: Union[str, Path] = "yolo11n.pt", task: Optional[str] = None, verbose: bool = False):
         """
         Initialize a YOLO model.
 
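The constructor now carries explicit type hints (Union[str, Path], Optional[str]). A minimal usage sketch based on the docstring examples above; the explicit-task call is illustrative:

from pathlib import Path

from ultralytics import YOLO

model = YOLO("yolo11n-seg.pt")  # task auto-detected as "segment" from the filename
model = YOLO(Path("yolo11n.yaml"), task="detect", verbose=True)  # task set explicitly from a YAML config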
@@ -30,7 +56,7 @@ class YOLO(Model):
 
         Args:
             model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
-            task (str
+            task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
                 Defaults to auto-detection based on model.
             verbose (bool): Display model info on load.
 
@@ -59,7 +85,7 @@ class YOLO(Model):
         self.__dict__ = new_instance.__dict__
 
     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Any]]:
         """Map head to model, trainer, validator, and predictor classes."""
         return {
             "classify": {
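task_map is now annotated as Dict[str, Dict[str, Any]]. For illustration, a short sketch inspecting the mapping; the exact inner keys are an assumption based on the docstring wording "model, trainer, validator, and predictor classes":

from ultralytics import YOLO

model = YOLO("yolo11n.pt")
for task, components in model.task_map.items():
    print(task, sorted(components))  # e.g. "detect" -> ["model", "predictor", "trainer", "validator"]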
@@ -96,9 +122,32 @@ class YOLOWorld(Model):
 
 
 class YOLOWorld(Model):
-    """
+    """
+    YOLO-World object detection model.
 
-
+    YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions
+    without requiring training on specific classes. It extends the YOLO architecture to support real-time
+    open-vocabulary detection.
+
+    Attributes:
+        model: The loaded YOLO-World model instance.
+        task: Always set to 'detect' for object detection.
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize YOLOv8-World model with a pre-trained model file.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+        set_classes: Set the model's class names for detection.
+
+    Examples:
+        Load a YOLOv8-World model
+        >>> model = YOLOWorld("yolov8s-world.pt")
+
+        Set custom classes for detection
+        >>> model.set_classes(["person", "car", "bicycle"])
+    """
+
+    def __init__(self, model: Union[str, Path] = "yolov8s-world.pt", verbose: bool = False) -> None:
         """
         Initialize YOLOv8-World model with a pre-trained model file.
 
@@ -116,7 +165,7 @@ class YOLOWorld(Model):
         self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")
 
     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Any]]:
         """Map head to model, validator, and predictor classes."""
         return {
             "detect": {
@@ -127,12 +176,12 @@ class YOLOWorld(Model):
             }
         }
 
-    def set_classes(self, classes):
+    def set_classes(self, classes: List[str]) -> None:
         """
         Set the model's class names for detection.
 
         Args:
-            classes (
+            classes (List[str]): A list of categories i.e. ["person"].
         """
         self.model.set_classes(classes)
         # Remove background if it's given
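Combining the docstring examples with the newly typed set_classes(classes: List[str]), a minimal open-vocabulary sketch (the image path is a placeholder):

from ultralytics import YOLO

model = YOLO("yolov8s-world.pt")  # the world model is selected automatically from the filename
model.set_classes(["person", "car", "bicycle"])  # restrict detection to a custom vocabulary
results = model.predict("street.jpg")  # "street.jpg" is a placeholder image path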
@@ -147,9 +196,43 @@ class YOLOWorld(Model):
 
 
 class YOLOE(Model):
-    """
-
-
+    """
+    YOLOE object detection and segmentation model.
+
+    YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with
+    improved performance and additional features like visual and text positional embeddings.
+
+    Attributes:
+        model: The loaded YOLOE model instance.
+        task: The task type (detect or segment).
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize YOLOE model with a pre-trained model file.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+        get_text_pe: Get text positional embeddings for the given texts.
+        get_visual_pe: Get visual positional embeddings for the given image and visual features.
+        set_vocab: Set vocabulary and class names for the YOLOE model.
+        get_vocab: Get vocabulary for the given class names.
+        set_classes: Set the model's class names and embeddings for detection.
+        val: Validate the model using text or visual prompts.
+        predict: Run prediction on images, videos, directories, streams, etc.
+
+    Examples:
+        Load a YOLOE detection model
+        >>> model = YOLOE("yoloe-11s-seg.pt")
+
+        Set vocabulary and class names
+        >>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
+
+        Predict with visual prompts
+        >>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
+        >>> results = model.predict("image.jpg", visual_prompts=prompts)
+    """
+
+    def __init__(
+        self, model: Union[str, Path] = "yoloe-11s-seg.pt", task: Optional[str] = None, verbose: bool = False
+    ) -> None:
         """
         Initialize YOLOE model with a pre-trained model file.
 
@@ -165,7 +248,7 @@ class YOLOE(Model):
         self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")
 
     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Any]]:
         """Map head to model, validator, and predictor classes."""
         return {
             "detect": {
@@ -210,7 +293,7 @@ class YOLOE(Model):
         assert isinstance(self.model, YOLOEModel)
         return self.model.get_visual_pe(img, visual)
 
-    def set_vocab(self, vocab, names):
+    def set_vocab(self, vocab: List[str], names: List[str]) -> None:
         """
         Set vocabulary and class names for the YOLOE model.
 
@@ -218,8 +301,8 @@ class YOLOE(Model):
         classification tasks. The model must be an instance of YOLOEModel.
 
         Args:
-            vocab (
-            names (
+            vocab (List[str]): Vocabulary list containing tokens or words used by the model for text processing.
+            names (List[str]): List of class names that the model can detect or classify.
 
         Raises:
             AssertionError: If the model is not an instance of YOLOEModel.
@@ -236,12 +319,12 @@ class YOLOE(Model):
         assert isinstance(self.model, YOLOEModel)
         return self.model.get_vocab(names)
 
-    def set_classes(self, classes, embeddings):
+    def set_classes(self, classes: List[str], embeddings) -> None:
         """
         Set the model's class names and embeddings for detection.
 
         Args:
-            classes (
+            classes (List[str]): A list of categories i.e. ["person"].
             embeddings (torch.Tensor): Embeddings corresponding to the classes.
         """
         assert isinstance(self.model, YOLOEModel)
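set_classes now requires the class names together with their embeddings; get_text_pe (listed in the Methods section of the YOLOE docstring above) is the natural source for those embeddings. A minimal sketch:

from ultralytics import YOLOE

model = YOLOE("yoloe-11s-seg.pt")
names = ["person", "bus"]
model.set_classes(names, model.get_text_pe(names))  # text embeddings computed for the same names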
@@ -257,8 +340,8 @@ class YOLOE(Model):
     def val(
         self,
         validator=None,
-        load_vp=False,
-        refer_data=None,
+        load_vp: bool = False,
+        refer_data: Optional[str] = None,
         **kwargs,
     ):
         """
@@ -285,7 +368,7 @@ class YOLOE(Model):
         self,
         source=None,
         stream: bool = False,
-        visual_prompts:
+        visual_prompts: Dict[str, List] = {},
         refer_image=None,
         predictor=None,
         **kwargs,
@@ -298,8 +381,8 @@ class YOLOE(Model):
                 directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
             stream (bool): Whether to stream the prediction results. If True, results are yielded as a
                 generator as they are computed.
-            visual_prompts (
-                'cls' keys when non-empty.
+            visual_prompts (Dict[str, List]): Dictionary containing visual prompts for the model. Must include
+                'bboxes' and 'cls' keys when non-empty.
             refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
             predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
                 loaded based on the task.
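With visual_prompts now typed as Dict[str, List], a prompt-based prediction sketch that mirrors the docstring example above (the image path and box coordinates are illustrative):

from ultralytics import YOLOE

model = YOLOE("yoloe-11s-seg.pt")
prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}  # both keys are required when non-empty
results = model.predict("image.jpg", visual_prompts=prompts)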
ultralytics/models/yolo/obb/predict.py
CHANGED
@@ -30,8 +30,6 @@ class OBBPredictor(DetectionPredictor):
         """
         Initialize OBBPredictor with optional model and data configuration overrides.
 
-        This constructor sets up an OBBPredictor instance for oriented bounding box detection tasks.
-
         Args:
             cfg (dict, optional): Default configuration for the predictor.
             overrides (dict, optional): Configuration overrides that take precedence over the default config.
@@ -51,14 +49,15 @@ class OBBPredictor(DetectionPredictor):
         Construct the result object from the prediction.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N,
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
                 the last dimension contains [x, y, w, h, confidence, class_id, angle].
             img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
 
         Returns:
-            (Results): The result object containing the original image, image path, class names, and oriented bounding
+            (Results): The result object containing the original image, image path, class names, and oriented bounding
+                boxes.
         """
         rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
         rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
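construct_result packs the regularized, rescaled rotated boxes into a Results object; downstream they are exposed through the obb attribute. A minimal sketch (the image path is a placeholder):

from ultralytics import YOLO

model = YOLO("yolo11n-obb.pt")
for result in model("boats.jpg"):  # placeholder image path
    print(result.obb.xywhr)  # (N, 5) rotated boxes: center x, center y, width, height, angle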
ultralytics/models/yolo/obb/train.py
CHANGED
@@ -1,6 +1,8 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from copy import copy
+from pathlib import Path
+from typing import Any, List, Optional, Union
 
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import OBBModel
@@ -11,8 +13,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     """
     A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
 
+    This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for
+    detecting objects at arbitrary angles rather than just axis-aligned rectangles.
+
     Attributes:
-        loss_names (
+        loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
+            and dfl_loss.
 
     Methods:
         get_model: Return OBBModel initialized with specified config and weights.
@@ -25,7 +31,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
         >>> trainer.train()
     """
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+    def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[dict] = None, _callbacks: Optional[List[Any]] = None):
         """
         Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
 
@@ -37,7 +43,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
                 model configuration.
             overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
                 will take precedence over those in cfg.
-            _callbacks (
+            _callbacks (List[Any], optional): List of callback functions to be invoked during training.
 
         Examples:
             >>> from ultralytics.models.yolo.obb import OBBTrainer
@@ -50,14 +56,16 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
         overrides["task"] = "obb"
         super().__init__(cfg, overrides, _callbacks)
 
-    def get_model(
+    def get_model(
+        self, cfg: Optional[Union[str, dict]] = None, weights: Optional[Union[str, Path]] = None, verbose: bool = True
+    ) -> OBBModel:
         """
         Return OBBModel initialized with specified config and weights.
 
         Args:
-            cfg (str | dict
+            cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
                 containing configuration parameters, or None to use default configuration.
-            weights (str | Path
+            weights (str | Path, optional): Path to pretrained weights file. If None, random initialization is used.
             verbose (bool): Whether to display model information during initialization.
 
         Returns:
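The docstring's Examples section sketches direct trainer usage; expanded slightly below (dataset and epoch values are illustrative):

from ultralytics.models.yolo.obb import OBBTrainer

args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)  # illustrative overrides
trainer = OBBTrainer(overrides=args)
trainer.train()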
ultralytics/models/yolo/obb/val.py
CHANGED
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from pathlib import Path
+from typing import Dict, List, Tuple, Union
 
 import torch
 
@@ -63,34 +64,31 @@ class OBBValidator(DetectionValidator):
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
 
-    def _process_batch(self, detections, gt_bboxes, gt_cls):
+    def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
         """
-
+        Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
 
         Args:
-            detections (torch.Tensor):
-
-            gt_bboxes (torch.Tensor):
-
-            gt_cls (torch.Tensor):
+            detections (torch.Tensor): Detected bounding boxes and associated data with shape (N, 7) where each
+                detection is represented as (x1, y1, x2, y2, conf, class, angle).
+            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (M, 5) where each box is represented
+                as (x1, y1, x2, y2, angle).
+            gt_cls (torch.Tensor): Class labels for the ground truth bounding boxes with shape (M,).
 
         Returns:
-            (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU
-
+            (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU levels for each
+                detection, indicating the accuracy of predictions compared to the ground truth.
 
         Examples:
            >>> detections = torch.rand(100, 7)  # 100 sample detections
           >>> gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
            >>> gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
-            >>> correct_matrix =
-
-        Note:
-            This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
+            >>> correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)
         """
         iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
-    def _prepare_batch(self, si, batch):
+    def _prepare_batch(self, si: int, batch: Dict) -> Dict:
         """
         Prepare batch data for OBB validation with proper scaling and formatting.
 
@@ -104,8 +102,8 @@ class OBBValidator(DetectionValidator):
                 - img: Batch of images
                 - ratio_pad: Ratio and padding information
 
-
-
+        Returns:
+            (dict): Prepared batch data with scaled bounding boxes and metadata.
         """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
@@ -118,7 +116,7 @@ class OBBValidator(DetectionValidator):
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
-    def _prepare_pred(self, pred, pbatch):
+    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict) -> torch.Tensor:
         """
         Prepare predictions by scaling bounding boxes to original image dimensions.
 
@@ -141,7 +139,7 @@ class OBBValidator(DetectionValidator):
         )  # native-space pred
         return predn
 
-    def plot_predictions(self, batch, preds, ni):
+    def plot_predictions(self, batch: Dict, preds: List[torch.Tensor], ni: int):
         """
         Plot predicted bounding boxes on input images and save the result.
 
@@ -165,7 +163,7 @@ class OBBValidator(DetectionValidator):
             on_plot=self.on_plot,
         )  # pred
 
-    def pred_to_json(self, predn, filename):
+    def pred_to_json(self, predn: torch.Tensor, filename: Union[str, Path]):
         """
         Convert YOLO predictions to COCO JSON format with rotated bounding box information.
 
@@ -194,9 +192,9 @@ class OBBValidator(DetectionValidator):
             }
         )
 
-    def save_one_txt(self, predn, save_conf, shape, file):
+    def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Union[Path, str]):
         """
-        Save YOLO OBB
+        Save YOLO OBB detections to a text file in normalized coordinates.
 
         Args:
             predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
@@ -224,8 +222,16 @@ class OBBValidator(DetectionValidator):
             obb=obb,
         ).save_txt(file, save_conf=save_conf)
 
-    def eval_json(self, stats):
-        """
+    def eval_json(self, stats: Dict) -> Dict:
+        """
+        Evaluate YOLO output in JSON format and save predictions in DOTA format.
+
+        Args:
+            stats (dict): Performance statistics dictionary.
+
+        Returns:
+            (dict): Updated performance statistics.
+        """
         if self.args.save_json and self.is_dota and len(self.jdict):
             import json
             import re
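eval_json only runs when save_json is set and the dataset is DOTA-format (see the guard in the body above). A minimal validation sketch that would exercise it (the dataset name is illustrative):

from ultralytics import YOLO

model = YOLO("yolo11n-obb.pt")
metrics = model.val(data="dota8.yaml", save_json=True)  # predictions are additionally saved in DOTA format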
ultralytics/models/yolo/pose/predict.py
CHANGED
@@ -16,7 +16,7 @@ class PosePredictor(DetectionPredictor):
         model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
 
     Methods:
-        construct_result:
+        construct_result: Construct the result object from the prediction, including keypoints.
 
     Examples:
         >>> from ultralytics.utils import ASSETS
@@ -28,13 +28,13 @@ class PosePredictor(DetectionPredictor):
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """
-        Initialize PosePredictor
+        Initialize PosePredictor for pose estimation tasks.
 
-
-
+        Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific
+        warnings for Apple MPS.
 
         Args:
-            cfg (Any): Configuration for the predictor.
+            cfg (Any): Configuration for the predictor.
             overrides (dict, optional): Configuration overrides that take precedence over cfg.
             _callbacks (list, optional): List of callback functions to be invoked during prediction.
 
@@ -57,8 +57,8 @@ class PosePredictor(DetectionPredictor):
         """
         Construct the result object from the prediction, including keypoints.
 
-
-
+        Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
+        result object.
 
         Args:
             pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
@@ -68,7 +68,8 @@ class PosePredictor(DetectionPredictor):
             img_path (str): The path to the original image file.
 
         Returns:
-            (Results): The result object containing the original image, image path, class names, bounding boxes, and
+            (Results): The result object containing the original image, image path, class names, bounding boxes, and
+                keypoints.
         """
         result = super().construct_result(pred, img, orig_img, img_path)
         # Extract keypoints from prediction and reshape according to model's keypoint shape
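construct_result attaches the reshaped keypoints to the Results object; they are read back via the keypoints attribute. A minimal sketch using the bundled asset image referenced in the docstring examples:

from ultralytics import YOLO
from ultralytics.utils import ASSETS

model = YOLO("yolo11n-pose.pt")
for result in model(ASSETS / "bus.jpg"):
    print(result.keypoints.xy)  # per-instance keypoint coordinates in pixel space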
ultralytics/models/yolo/pose/train.py
CHANGED
@@ -1,6 +1,8 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from copy import copy
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
 
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import PoseModel
@@ -19,14 +21,15 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         args (dict): Configuration arguments for training.
         model (PoseModel): The pose estimation model being trained.
         data (dict): Dataset configuration including keypoint shape information.
-        loss_names (
+        loss_names (tuple): Names of the loss components used in training.
 
     Methods:
-        get_model:
-        set_model_attributes:
-        get_validator:
-        plot_training_samples:
-        plot_metrics:
+        get_model: Retrieve a pose estimation model with specified configuration.
+        set_model_attributes: Set keypoints shape attribute on the model.
+        get_validator: Create a validator instance for model evaluation.
+        plot_training_samples: Visualize training samples with keypoints.
+        plot_metrics: Generate and save training/validation metric plots.
+        get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.
 
     Examples:
         >>> from ultralytics.models.yolo.pose import PoseTrainer
@@ -35,7 +38,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         >>> trainer.train()
     """
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+    def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):
         """
         Initialize a PoseTrainer object for training YOLO pose estimation models.
 
@@ -68,13 +71,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
                 "See https://github.com/ultralytics/ultralytics/issues/4031."
             )
 
-    def get_model(
+    def get_model(
+        self,
+        cfg: Optional[Union[str, Path, Dict[str, Any]]] = None,
+        weights: Optional[Union[str, Path]] = None,
+        verbose: bool = True,
+    ) -> PoseModel:
         """
         Get pose estimation model with specified configuration and weights.
 
         Args:
-            cfg (str | Path | dict
-            weights (str | Path
+            cfg (str | Path | dict, optional): Model configuration file path or dictionary.
+            weights (str | Path, optional): Path to the model weights file.
            verbose (bool): Whether to display model information.
 
        Returns:
@@ -89,18 +97,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         return model
 
     def set_model_attributes(self):
-        """
+        """Set keypoints shape attribute of PoseModel."""
         super().set_model_attributes()
         self.model.kpt_shape = self.data["kpt_shape"]
 
     def get_validator(self):
-        """
+        """Return an instance of the PoseValidator class for validation."""
         self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
         return yolo.pose.PoseValidator(
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
-    def plot_training_samples(self, batch, ni):
+    def plot_training_samples(self, batch: Dict[str, Any], ni: int):
         """
         Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.
 
@@ -135,12 +143,12 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         )
 
     def plot_metrics(self):
-        """
+        """Plot training/validation metrics."""
         plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png
 
-    def get_dataset(self):
+    def get_dataset(self) -> Dict[str, Any]:
         """
-
+        Retrieve the dataset and ensure it contains the required `kpt_shape` key.
 
         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.