dgenerate-ultralytics-headless 8.3.194__py3-none-any.whl → 8.3.196__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/METADATA +1 -2
- {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/RECORD +107 -106
- tests/test_python.py +1 -1
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +9 -8
- ultralytics/cfg/default.yaml +1 -0
- ultralytics/data/annotator.py +1 -1
- ultralytics/data/augment.py +76 -76
- ultralytics/data/base.py +12 -12
- ultralytics/data/build.py +5 -1
- ultralytics/data/converter.py +4 -4
- ultralytics/data/dataset.py +7 -7
- ultralytics/data/loaders.py +15 -15
- ultralytics/data/split_dota.py +10 -10
- ultralytics/data/utils.py +12 -12
- ultralytics/engine/exporter.py +19 -31
- ultralytics/engine/model.py +13 -13
- ultralytics/engine/predictor.py +16 -14
- ultralytics/engine/results.py +21 -21
- ultralytics/engine/trainer.py +15 -4
- ultralytics/engine/validator.py +6 -2
- ultralytics/hub/google/__init__.py +2 -2
- ultralytics/hub/session.py +7 -7
- ultralytics/models/fastsam/model.py +5 -5
- ultralytics/models/fastsam/predict.py +11 -11
- ultralytics/models/nas/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +2 -2
- ultralytics/models/rtdetr/val.py +4 -4
- ultralytics/models/sam/amg.py +6 -6
- ultralytics/models/sam/build.py +9 -9
- ultralytics/models/sam/model.py +7 -7
- ultralytics/models/sam/modules/blocks.py +6 -6
- ultralytics/models/sam/modules/decoders.py +1 -1
- ultralytics/models/sam/modules/encoders.py +27 -27
- ultralytics/models/sam/modules/sam.py +4 -4
- ultralytics/models/sam/modules/tiny_encoder.py +18 -18
- ultralytics/models/sam/modules/utils.py +8 -8
- ultralytics/models/sam/predict.py +63 -63
- ultralytics/models/utils/loss.py +22 -22
- ultralytics/models/utils/ops.py +8 -8
- ultralytics/models/yolo/classify/predict.py +2 -2
- ultralytics/models/yolo/classify/train.py +9 -19
- ultralytics/models/yolo/classify/val.py +4 -4
- ultralytics/models/yolo/detect/predict.py +3 -3
- ultralytics/models/yolo/detect/train.py +38 -12
- ultralytics/models/yolo/detect/val.py +38 -37
- ultralytics/models/yolo/model.py +6 -6
- ultralytics/models/yolo/obb/train.py +1 -10
- ultralytics/models/yolo/obb/val.py +13 -13
- ultralytics/models/yolo/pose/train.py +1 -9
- ultralytics/models/yolo/pose/val.py +12 -12
- ultralytics/models/yolo/segment/predict.py +4 -4
- ultralytics/models/yolo/segment/train.py +2 -10
- ultralytics/models/yolo/segment/val.py +15 -15
- ultralytics/models/yolo/world/train.py +13 -13
- ultralytics/models/yolo/world/train_world.py +3 -3
- ultralytics/models/yolo/yoloe/predict.py +4 -4
- ultralytics/models/yolo/yoloe/train.py +7 -16
- ultralytics/models/yolo/yoloe/val.py +0 -7
- ultralytics/nn/autobackend.py +2 -2
- ultralytics/nn/modules/block.py +6 -6
- ultralytics/nn/modules/conv.py +2 -2
- ultralytics/nn/modules/head.py +6 -5
- ultralytics/nn/tasks.py +17 -15
- ultralytics/nn/text_model.py +3 -3
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +3 -3
- ultralytics/solutions/config.py +5 -5
- ultralytics/solutions/distance_calculation.py +2 -2
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +4 -4
- ultralytics/solutions/object_counter.py +4 -4
- ultralytics/solutions/parking_management.py +7 -7
- ultralytics/solutions/queue_management.py +3 -3
- ultralytics/solutions/region_counter.py +4 -4
- ultralytics/solutions/similarity_search.py +2 -2
- ultralytics/solutions/solutions.py +48 -48
- ultralytics/solutions/streamlit_inference.py +1 -1
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/solutions/vision_eye.py +1 -1
- ultralytics/trackers/byte_tracker.py +11 -11
- ultralytics/trackers/utils/gmc.py +3 -3
- ultralytics/trackers/utils/matching.py +5 -5
- ultralytics/utils/__init__.py +30 -19
- ultralytics/utils/autodevice.py +2 -2
- ultralytics/utils/benchmarks.py +10 -10
- ultralytics/utils/callbacks/clearml.py +1 -1
- ultralytics/utils/callbacks/comet.py +5 -5
- ultralytics/utils/callbacks/tensorboard.py +2 -2
- ultralytics/utils/checks.py +7 -5
- ultralytics/utils/cpu.py +90 -0
- ultralytics/utils/dist.py +1 -1
- ultralytics/utils/downloads.py +2 -2
- ultralytics/utils/export.py +5 -5
- ultralytics/utils/instance.py +2 -2
- ultralytics/utils/loss.py +14 -8
- ultralytics/utils/metrics.py +35 -35
- ultralytics/utils/nms.py +4 -4
- ultralytics/utils/ops.py +1 -1
- ultralytics/utils/patches.py +2 -2
- ultralytics/utils/plotting.py +10 -9
- ultralytics/utils/torch_utils.py +113 -15
- ultralytics/utils/triton.py +5 -5
- {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.194.dist-info → dgenerate_ultralytics_headless-8.3.196.dist-info}/top_level.txt +0 -0
@@ -8,16 +8,17 @@ from copy import copy
|
|
8
8
|
from typing import Any
|
9
9
|
|
10
10
|
import numpy as np
|
11
|
+
import torch
|
11
12
|
import torch.nn as nn
|
12
13
|
|
13
14
|
from ultralytics.data import build_dataloader, build_yolo_dataset
|
14
15
|
from ultralytics.engine.trainer import BaseTrainer
|
15
16
|
from ultralytics.models import yolo
|
16
17
|
from ultralytics.nn.tasks import DetectionModel
|
17
|
-
from ultralytics.utils import LOGGER, RANK
|
18
|
+
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
|
18
19
|
from ultralytics.utils.patches import override_configs
|
19
20
|
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
|
20
|
-
from ultralytics.utils.torch_utils import
|
21
|
+
from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model
|
21
22
|
|
22
23
|
|
23
24
|
class DetectionTrainer(BaseTrainer):
|
@@ -29,7 +30,7 @@ class DetectionTrainer(BaseTrainer):
|
|
29
30
|
|
30
31
|
Attributes:
|
31
32
|
model (DetectionModel): The YOLO detection model being trained.
|
32
|
-
data (
|
33
|
+
data (dict): Dictionary containing dataset information including class names and number of classes.
|
33
34
|
loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
|
34
35
|
|
35
36
|
Methods:
|
@@ -53,6 +54,18 @@ class DetectionTrainer(BaseTrainer):
|
|
53
54
|
>>> trainer.train()
|
54
55
|
"""
|
55
56
|
|
57
|
+
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
58
|
+
"""
|
59
|
+
Initialize a DetectionTrainer object for training YOLO object detection models.
|
60
|
+
|
61
|
+
Args:
|
62
|
+
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
63
|
+
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
|
64
|
+
_callbacks (list, optional): List of callback functions to be executed during training.
|
65
|
+
"""
|
66
|
+
super().__init__(cfg, overrides, _callbacks)
|
67
|
+
self.dynamic_tensors = ["batch_idx", "cls", "bboxes"]
|
68
|
+
|
56
69
|
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
|
57
70
|
"""
|
58
71
|
Build YOLO Dataset for training or validation.
|
@@ -65,7 +78,7 @@ class DetectionTrainer(BaseTrainer):
|
|
65
78
|
Returns:
|
66
79
|
(Dataset): YOLO dataset object configured for the specified mode.
|
67
80
|
"""
|
68
|
-
gs = max(int(
|
81
|
+
gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
|
69
82
|
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
|
70
83
|
|
71
84
|
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
|
@@ -88,20 +101,29 @@ class DetectionTrainer(BaseTrainer):
|
|
88
101
|
if getattr(dataset, "rect", False) and shuffle:
|
89
102
|
LOGGER.warning("'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
90
103
|
shuffle = False
|
91
|
-
|
92
|
-
|
104
|
+
return build_dataloader(
|
105
|
+
dataset,
|
106
|
+
batch=batch_size,
|
107
|
+
workers=self.args.workers if mode == "train" else self.args.workers * 2,
|
108
|
+
shuffle=shuffle,
|
109
|
+
rank=rank,
|
110
|
+
drop_last=self.args.compile and mode == "train",
|
111
|
+
)
|
93
112
|
|
94
113
|
def preprocess_batch(self, batch: dict) -> dict:
|
95
114
|
"""
|
96
115
|
Preprocess a batch of images by scaling and converting to float.
|
97
116
|
|
98
117
|
Args:
|
99
|
-
batch (
|
118
|
+
batch (dict): Dictionary containing batch data with 'img' tensor.
|
100
119
|
|
101
120
|
Returns:
|
102
|
-
(
|
121
|
+
(dict): Preprocessed batch with normalized images.
|
103
122
|
"""
|
104
|
-
|
123
|
+
for k, v in batch.items():
|
124
|
+
if isinstance(v, torch.Tensor):
|
125
|
+
batch[k] = v.to(self.device, non_blocking=True)
|
126
|
+
batch["img"] = batch["img"].float() / 255
|
105
127
|
if self.args.multi_scale:
|
106
128
|
imgs = batch["img"]
|
107
129
|
sz = (
|
@@ -116,6 +138,10 @@ class DetectionTrainer(BaseTrainer):
|
|
116
138
|
] # new shape (stretched to gs-multiple)
|
117
139
|
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
118
140
|
batch["img"] = imgs
|
141
|
+
|
142
|
+
if self.args.compile:
|
143
|
+
for k in self.dynamic_tensors:
|
144
|
+
torch._dynamo.maybe_mark_dynamic(batch[k], 0)
|
119
145
|
return batch
|
120
146
|
|
121
147
|
def set_model_attributes(self):
|
@@ -158,11 +184,11 @@ class DetectionTrainer(BaseTrainer):
|
|
158
184
|
Return a loss dict with labeled training loss items tensor.
|
159
185
|
|
160
186
|
Args:
|
161
|
-
loss_items (
|
187
|
+
loss_items (list[float], optional): List of loss values.
|
162
188
|
prefix (str): Prefix for keys in the returned dictionary.
|
163
189
|
|
164
190
|
Returns:
|
165
|
-
(
|
191
|
+
(dict | list): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
|
166
192
|
"""
|
167
193
|
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
168
194
|
if loss_items is not None:
|
@@ -186,7 +212,7 @@ class DetectionTrainer(BaseTrainer):
|
|
186
212
|
Plot training samples with their annotations.
|
187
213
|
|
188
214
|
Args:
|
189
|
-
batch (
|
215
|
+
batch (dict[str, Any]): Dictionary containing batch data.
|
190
216
|
ni (int): Number of iterations.
|
191
217
|
"""
|
192
218
|
plot_images(
|
@@ -27,13 +27,13 @@ class DetectionValidator(BaseValidator):
|
|
27
27
|
Attributes:
|
28
28
|
is_coco (bool): Whether the dataset is COCO.
|
29
29
|
is_lvis (bool): Whether the dataset is LVIS.
|
30
|
-
class_map (
|
30
|
+
class_map (list[int]): Mapping from model class indices to dataset class indices.
|
31
31
|
metrics (DetMetrics): Object detection metrics calculator.
|
32
32
|
iouv (torch.Tensor): IoU thresholds for mAP calculation.
|
33
33
|
niou (int): Number of IoU thresholds.
|
34
|
-
lb (
|
35
|
-
jdict (
|
36
|
-
stats (
|
34
|
+
lb (list[Any]): List for storing ground truth labels for hybrid saving.
|
35
|
+
jdict (list[dict[str, Any]]): List for storing JSON detection results.
|
36
|
+
stats (dict[str, list[torch.Tensor]]): Dictionary for storing statistics during validation.
|
37
37
|
|
38
38
|
Examples:
|
39
39
|
>>> from ultralytics.models.yolo.detect import DetectionValidator
|
@@ -49,8 +49,8 @@ class DetectionValidator(BaseValidator):
|
|
49
49
|
Args:
|
50
50
|
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
|
51
51
|
save_dir (Path, optional): Directory to save results.
|
52
|
-
args (
|
53
|
-
_callbacks (
|
52
|
+
args (dict[str, Any], optional): Arguments for the validator.
|
53
|
+
_callbacks (list[Any], optional): List of callback functions.
|
54
54
|
"""
|
55
55
|
super().__init__(dataloader, save_dir, args, _callbacks)
|
56
56
|
self.is_coco = False
|
@@ -66,16 +66,15 @@ class DetectionValidator(BaseValidator):
|
|
66
66
|
Preprocess batch of images for YOLO validation.
|
67
67
|
|
68
68
|
Args:
|
69
|
-
batch (
|
69
|
+
batch (dict[str, Any]): Batch containing images and annotations.
|
70
70
|
|
71
71
|
Returns:
|
72
|
-
(
|
72
|
+
(dict[str, Any]): Preprocessed batch.
|
73
73
|
"""
|
74
|
-
|
74
|
+
for k, v in batch.items():
|
75
|
+
if isinstance(v, torch.Tensor):
|
76
|
+
batch[k] = v.to(self.device, non_blocking=True)
|
75
77
|
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
|
76
|
-
for k in {"batch_idx", "cls", "bboxes"}:
|
77
|
-
batch[k] = batch[k].to(self.device, non_blocking=True)
|
78
|
-
|
79
78
|
return batch
|
80
79
|
|
81
80
|
def init_metrics(self, model: torch.nn.Module) -> None:
|
@@ -114,7 +113,7 @@ class DetectionValidator(BaseValidator):
|
|
114
113
|
preds (torch.Tensor): Raw predictions from the model.
|
115
114
|
|
116
115
|
Returns:
|
117
|
-
(
|
116
|
+
(list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
|
118
117
|
'bboxes', 'conf', 'cls', and 'extra' tensors.
|
119
118
|
"""
|
120
119
|
outputs = nms.non_max_suppression(
|
@@ -136,10 +135,10 @@ class DetectionValidator(BaseValidator):
|
|
136
135
|
|
137
136
|
Args:
|
138
137
|
si (int): Batch index.
|
139
|
-
batch (
|
138
|
+
batch (dict[str, Any]): Batch data containing images and annotations.
|
140
139
|
|
141
140
|
Returns:
|
142
|
-
(
|
141
|
+
(dict[str, Any]): Prepared batch with processed annotations.
|
143
142
|
"""
|
144
143
|
idx = batch["batch_idx"] == si
|
145
144
|
cls = batch["cls"][idx].squeeze(-1)
|
@@ -163,10 +162,10 @@ class DetectionValidator(BaseValidator):
|
|
163
162
|
Prepare predictions for evaluation against ground truth.
|
164
163
|
|
165
164
|
Args:
|
166
|
-
pred (
|
165
|
+
pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
|
167
166
|
|
168
167
|
Returns:
|
169
|
-
(
|
168
|
+
(dict[str, torch.Tensor]): Prepared predictions in native space.
|
170
169
|
"""
|
171
170
|
if self.args.single_cls:
|
172
171
|
pred["cls"] *= 0
|
@@ -177,8 +176,8 @@ class DetectionValidator(BaseValidator):
|
|
177
176
|
Update metrics with new predictions and ground truth.
|
178
177
|
|
179
178
|
Args:
|
180
|
-
preds (
|
181
|
-
batch (
|
179
|
+
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
|
180
|
+
batch (dict[str, Any]): Batch data containing ground truth.
|
182
181
|
"""
|
183
182
|
for si, pred in enumerate(preds):
|
184
183
|
self.seen += 1
|
@@ -232,7 +231,7 @@ class DetectionValidator(BaseValidator):
|
|
232
231
|
Calculate and return metrics statistics.
|
233
232
|
|
234
233
|
Returns:
|
235
|
-
(
|
234
|
+
(dict[str, Any]): Dictionary containing metrics results.
|
236
235
|
"""
|
237
236
|
self.metrics.process(save_dir=self.save_dir, plot=self.args.plots, on_plot=self.on_plot)
|
238
237
|
self.metrics.clear_stats()
|
@@ -263,11 +262,11 @@ class DetectionValidator(BaseValidator):
|
|
263
262
|
Return correct prediction matrix.
|
264
263
|
|
265
264
|
Args:
|
266
|
-
preds (
|
267
|
-
batch (
|
265
|
+
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
|
266
|
+
batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
|
268
267
|
|
269
268
|
Returns:
|
270
|
-
(
|
269
|
+
(dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
|
271
270
|
"""
|
272
271
|
if len(batch["cls"]) == 0 or len(preds["cls"]) == 0:
|
273
272
|
return {"tp": np.zeros((len(preds["cls"]), self.niou), dtype=bool)}
|
@@ -300,14 +299,16 @@ class DetectionValidator(BaseValidator):
|
|
300
299
|
(torch.utils.data.DataLoader): Dataloader for validation.
|
301
300
|
"""
|
302
301
|
dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
|
303
|
-
return build_dataloader(
|
302
|
+
return build_dataloader(
|
303
|
+
dataset, batch_size, self.args.workers, shuffle=False, rank=-1, drop_last=self.args.compile
|
304
|
+
)
|
304
305
|
|
305
306
|
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
|
306
307
|
"""
|
307
308
|
Plot validation image samples.
|
308
309
|
|
309
310
|
Args:
|
310
|
-
batch (
|
311
|
+
batch (dict[str, Any]): Batch containing images and annotations.
|
311
312
|
ni (int): Batch index.
|
312
313
|
"""
|
313
314
|
plot_images(
|
@@ -325,8 +326,8 @@ class DetectionValidator(BaseValidator):
|
|
325
326
|
Plot predicted bounding boxes on input images and save the result.
|
326
327
|
|
327
328
|
Args:
|
328
|
-
batch (
|
329
|
-
preds (
|
329
|
+
batch (dict[str, Any]): Batch containing images and annotations.
|
330
|
+
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
|
330
331
|
ni (int): Batch index.
|
331
332
|
max_det (Optional[int]): Maximum number of detections to plot.
|
332
333
|
"""
|
@@ -352,9 +353,9 @@ class DetectionValidator(BaseValidator):
|
|
352
353
|
Save YOLO detections to a txt file in normalized coordinates in a specific format.
|
353
354
|
|
354
355
|
Args:
|
355
|
-
predn (
|
356
|
+
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
|
356
357
|
save_conf (bool): Whether to save confidence scores.
|
357
|
-
shape (
|
358
|
+
shape (tuple[int, int]): Shape of the original image (height, width).
|
358
359
|
file (Path): File path to save the detections.
|
359
360
|
"""
|
360
361
|
from ultralytics.engine.results import Results
|
@@ -371,9 +372,9 @@ class DetectionValidator(BaseValidator):
|
|
371
372
|
Serialize YOLO predictions to COCO json format.
|
372
373
|
|
373
374
|
Args:
|
374
|
-
predn (
|
375
|
+
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
|
375
376
|
with bounding box coordinates, confidence scores, and class predictions.
|
376
|
-
pbatch (
|
377
|
+
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
377
378
|
|
378
379
|
Examples:
|
379
380
|
>>> result = {
|
@@ -417,10 +418,10 @@ class DetectionValidator(BaseValidator):
|
|
417
418
|
Evaluate YOLO output in JSON format and return performance statistics.
|
418
419
|
|
419
420
|
Args:
|
420
|
-
stats (
|
421
|
+
stats (dict[str, Any]): Current statistics dictionary.
|
421
422
|
|
422
423
|
Returns:
|
423
|
-
(
|
424
|
+
(dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.
|
424
425
|
"""
|
425
426
|
pred_json = self.save_dir / "predictions.json" # predictions
|
426
427
|
anno_json = (
|
@@ -446,16 +447,16 @@ class DetectionValidator(BaseValidator):
|
|
446
447
|
including mAP50, mAP50-95, and LVIS-specific metrics if applicable.
|
447
448
|
|
448
449
|
Args:
|
449
|
-
stats (
|
450
|
+
stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
|
450
451
|
pred_json (str | Path): Path to JSON file containing predictions in COCO format.
|
451
452
|
anno_json (str | Path): Path to JSON file containing ground truth annotations in COCO format.
|
452
|
-
iou_types (str |
|
453
|
+
iou_types (str | list[str]): IoU type(s) for evaluation. Can be a single string or a list of strings.
|
453
454
|
Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
|
454
|
-
suffix (str |
|
455
|
+
suffix (str | list[str]): Suffix to append to metric names in stats dictionary. Should correspond
|
455
456
|
to iou_types if multiple types provided. Defaults to "Box".
|
456
457
|
|
457
458
|
Returns:
|
458
|
-
(
|
459
|
+
(dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
|
459
460
|
"""
|
460
461
|
if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
|
461
462
|
LOGGER.info(f"\nEvaluating faster-coco-eval mAP using {pred_json} and {anno_json}...")
|
ultralytics/models/yolo/model.py
CHANGED
@@ -185,7 +185,7 @@ class YOLOWorld(Model):
|
|
185
185
|
Set the model's class names for detection.
|
186
186
|
|
187
187
|
Args:
|
188
|
-
classes (
|
188
|
+
classes (list[str]): A list of categories i.e. ["person"].
|
189
189
|
"""
|
190
190
|
self.model.set_classes(classes)
|
191
191
|
# Remove background if it's given
|
@@ -299,8 +299,8 @@ class YOLOE(Model):
|
|
299
299
|
classification tasks. The model must be an instance of YOLOEModel.
|
300
300
|
|
301
301
|
Args:
|
302
|
-
vocab (
|
303
|
-
names (
|
302
|
+
vocab (list[str]): Vocabulary list containing tokens or words used by the model for text processing.
|
303
|
+
names (list[str]): List of class names that the model can detect or classify.
|
304
304
|
|
305
305
|
Raises:
|
306
306
|
AssertionError: If the model is not an instance of YOLOEModel.
|
@@ -322,7 +322,7 @@ class YOLOE(Model):
|
|
322
322
|
Set the model's class names and embeddings for detection.
|
323
323
|
|
324
324
|
Args:
|
325
|
-
classes (
|
325
|
+
classes (list[str]): A list of categories i.e. ["person"].
|
326
326
|
embeddings (torch.Tensor): Embeddings corresponding to the classes.
|
327
327
|
"""
|
328
328
|
assert isinstance(self.model, YOLOEModel)
|
@@ -381,7 +381,7 @@ class YOLOE(Model):
|
|
381
381
|
directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
|
382
382
|
stream (bool): Whether to stream the prediction results. If True, results are yielded as a
|
383
383
|
generator as they are computed.
|
384
|
-
visual_prompts (
|
384
|
+
visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include
|
385
385
|
'bboxes' and 'cls' keys when non-empty.
|
386
386
|
refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
|
387
387
|
predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
|
@@ -389,7 +389,7 @@ class YOLOE(Model):
|
|
389
389
|
**kwargs (Any): Additional keyword arguments passed to the predictor.
|
390
390
|
|
391
391
|
Returns:
|
392
|
-
(
|
392
|
+
(list | generator): List of Results objects or generator of Results objects if stream=True.
|
393
393
|
|
394
394
|
Examples:
|
395
395
|
>>> model = YOLOE("yoloe-11s-seg.pt")
|
@@ -37,21 +37,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
|
|
37
37
|
"""
|
38
38
|
Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
|
39
39
|
|
40
|
-
This trainer extends the DetectionTrainer class to specialize in training models that detect oriented
|
41
|
-
bounding boxes. It automatically sets the task to 'obb' in the configuration.
|
42
|
-
|
43
40
|
Args:
|
44
41
|
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
|
45
42
|
model configuration.
|
46
43
|
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
|
47
44
|
will take precedence over those in cfg.
|
48
|
-
_callbacks (
|
49
|
-
|
50
|
-
Examples:
|
51
|
-
>>> from ultralytics.models.yolo.obb import OBBTrainer
|
52
|
-
>>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
|
53
|
-
>>> trainer = OBBTrainer(overrides=args)
|
54
|
-
>>> trainer.train()
|
45
|
+
_callbacks (list[Any], optional): List of callback functions to be invoked during training.
|
55
46
|
"""
|
56
47
|
if overrides is None:
|
57
48
|
overrides = {}
|
@@ -77,13 +77,13 @@ class OBBValidator(DetectionValidator):
|
|
77
77
|
Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
|
78
78
|
|
79
79
|
Args:
|
80
|
-
preds (
|
80
|
+
preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
|
81
81
|
class labels and bounding boxes.
|
82
|
-
batch (
|
82
|
+
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
|
83
83
|
class labels and bounding boxes.
|
84
84
|
|
85
85
|
Returns:
|
86
|
-
(
|
86
|
+
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
|
87
87
|
array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy
|
88
88
|
of predictions compared to the ground truth.
|
89
89
|
|
@@ -104,7 +104,7 @@ class OBBValidator(DetectionValidator):
|
|
104
104
|
preds (torch.Tensor): Raw predictions from the model.
|
105
105
|
|
106
106
|
Returns:
|
107
|
-
(
|
107
|
+
(list[dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
|
108
108
|
"""
|
109
109
|
preds = super().postprocess(preds)
|
110
110
|
for pred in preds:
|
@@ -117,7 +117,7 @@ class OBBValidator(DetectionValidator):
|
|
117
117
|
|
118
118
|
Args:
|
119
119
|
si (int): Batch index to process.
|
120
|
-
batch (
|
120
|
+
batch (dict[str, Any]): Dictionary containing batch data with keys:
|
121
121
|
- batch_idx: Tensor of batch indices
|
122
122
|
- cls: Tensor of class labels
|
123
123
|
- bboxes: Tensor of bounding boxes
|
@@ -126,7 +126,7 @@ class OBBValidator(DetectionValidator):
|
|
126
126
|
- ratio_pad: Ratio and padding information
|
127
127
|
|
128
128
|
Returns:
|
129
|
-
(
|
129
|
+
(dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
|
130
130
|
"""
|
131
131
|
idx = batch["batch_idx"] == si
|
132
132
|
cls = batch["cls"][idx].squeeze(-1)
|
@@ -150,8 +150,8 @@ class OBBValidator(DetectionValidator):
|
|
150
150
|
Plot predicted bounding boxes on input images and save the result.
|
151
151
|
|
152
152
|
Args:
|
153
|
-
batch (
|
154
|
-
preds (
|
153
|
+
batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
|
154
|
+
preds (list[torch.Tensor]): List of prediction tensors for each image in the batch.
|
155
155
|
ni (int): Batch index used for naming the output file.
|
156
156
|
|
157
157
|
Examples:
|
@@ -170,9 +170,9 @@ class OBBValidator(DetectionValidator):
|
|
170
170
|
Convert YOLO predictions to COCO JSON format with rotated bounding box information.
|
171
171
|
|
172
172
|
Args:
|
173
|
-
predn (
|
173
|
+
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
|
174
174
|
with bounding box coordinates, confidence scores, and class predictions.
|
175
|
-
pbatch (
|
175
|
+
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
176
176
|
|
177
177
|
Notes:
|
178
178
|
This method processes rotated bounding box predictions and converts them to both rbox format
|
@@ -204,7 +204,7 @@ class OBBValidator(DetectionValidator):
|
|
204
204
|
predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
|
205
205
|
class predictions, and angles in format (x, y, w, h, conf, cls, angle).
|
206
206
|
save_conf (bool): Whether to save confidence scores in the text file.
|
207
|
-
shape (
|
207
|
+
shape (tuple[int, int]): Original image shape in format (height, width).
|
208
208
|
file (Path): Output file path to save detections.
|
209
209
|
|
210
210
|
Examples:
|
@@ -237,10 +237,10 @@ class OBBValidator(DetectionValidator):
|
|
237
237
|
Evaluate YOLO output in JSON format and save predictions in DOTA format.
|
238
238
|
|
239
239
|
Args:
|
240
|
-
stats (
|
240
|
+
stats (dict[str, Any]): Performance statistics dictionary.
|
241
241
|
|
242
242
|
Returns:
|
243
|
-
(
|
243
|
+
(dict[str, Any]): Updated performance statistics.
|
244
244
|
"""
|
245
245
|
if self.args.save_json and self.is_dota and len(self.jdict):
|
246
246
|
import json
|
@@ -44,9 +44,6 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
44
44
|
"""
|
45
45
|
Initialize a PoseTrainer object for training YOLO pose estimation models.
|
46
46
|
|
47
|
-
This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and
|
48
|
-
handling specific configurations needed for keypoint detection models.
|
49
|
-
|
50
47
|
Args:
|
51
48
|
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
52
49
|
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
|
@@ -55,17 +52,12 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
55
52
|
Notes:
|
56
53
|
This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
|
57
54
|
A warning is issued when using Apple MPS device due to known bugs with pose models.
|
58
|
-
|
59
|
-
Examples:
|
60
|
-
>>> from ultralytics.models.yolo.pose import PoseTrainer
|
61
|
-
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
|
62
|
-
>>> trainer = PoseTrainer(overrides=args)
|
63
|
-
>>> trainer.train()
|
64
55
|
"""
|
65
56
|
if overrides is None:
|
66
57
|
overrides = {}
|
67
58
|
overrides["task"] = "pose"
|
68
59
|
super().__init__(cfg, overrides, _callbacks)
|
60
|
+
self.dynamic_tensors = ["batch_idx", "cls", "bboxes", "keypoints"]
|
69
61
|
|
70
62
|
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
71
63
|
LOGGER.warning(
|
@@ -22,7 +22,7 @@ class PoseValidator(DetectionValidator):
|
|
22
22
|
|
23
23
|
Attributes:
|
24
24
|
sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
|
25
|
-
kpt_shape (
|
25
|
+
kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
|
26
26
|
args (dict): Arguments for the validator including task set to "pose".
|
27
27
|
metrics (PoseMetrics): Metrics object for pose evaluation.
|
28
28
|
|
@@ -86,7 +86,7 @@ class PoseValidator(DetectionValidator):
|
|
86
86
|
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
87
87
|
"""Preprocess batch by converting keypoints data to float and moving it to the device."""
|
88
88
|
batch = super().preprocess(batch)
|
89
|
-
batch["keypoints"] = batch["keypoints"].
|
89
|
+
batch["keypoints"] = batch["keypoints"].float()
|
90
90
|
return batch
|
91
91
|
|
92
92
|
def get_desc(self) -> str:
|
@@ -132,7 +132,7 @@ class PoseValidator(DetectionValidator):
|
|
132
132
|
bounding boxes, confidence scores, class predictions, and keypoint data.
|
133
133
|
|
134
134
|
Returns:
|
135
|
-
(
|
135
|
+
(list[dict[str, torch.Tensor]]): List of processed prediction dictionaries, each containing:
|
136
136
|
- 'bboxes': Bounding box coordinates
|
137
137
|
- 'conf': Confidence scores
|
138
138
|
- 'cls': Class predictions
|
@@ -154,10 +154,10 @@ class PoseValidator(DetectionValidator):
|
|
154
154
|
|
155
155
|
Args:
|
156
156
|
si (int): Batch index.
|
157
|
-
batch (
|
157
|
+
batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
|
158
158
|
|
159
159
|
Returns:
|
160
|
-
(
|
160
|
+
(dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
|
161
161
|
|
162
162
|
Notes:
|
163
163
|
This method extends the parent class's _prepare_batch method by adding keypoint processing.
|
@@ -177,13 +177,13 @@ class PoseValidator(DetectionValidator):
|
|
177
177
|
Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
|
178
178
|
|
179
179
|
Args:
|
180
|
-
preds (
|
180
|
+
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
|
181
181
|
and 'keypoints' for keypoint predictions.
|
182
|
-
batch (
|
182
|
+
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
|
183
183
|
'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
|
184
184
|
|
185
185
|
Returns:
|
186
|
-
(
|
186
|
+
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
|
187
187
|
true positives across 10 IoU levels.
|
188
188
|
|
189
189
|
Notes:
|
@@ -207,9 +207,9 @@ class PoseValidator(DetectionValidator):
|
|
207
207
|
Save YOLO pose detections to a text file in normalized coordinates.
|
208
208
|
|
209
209
|
Args:
|
210
|
-
predn (
|
210
|
+
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints'.
|
211
211
|
save_conf (bool): Whether to save confidence scores.
|
212
|
-
shape (
|
212
|
+
shape (tuple[int, int]): Shape of the original image (height, width).
|
213
213
|
file (Path): Output file path to save detections.
|
214
214
|
|
215
215
|
Notes:
|
@@ -234,9 +234,9 @@ class PoseValidator(DetectionValidator):
|
|
234
234
|
to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
|
235
235
|
|
236
236
|
Args:
|
237
|
-
predn (
|
237
|
+
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
|
238
238
|
and 'keypoints' tensors.
|
239
|
-
pbatch (
|
239
|
+
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
240
240
|
|
241
241
|
Notes:
|
242
242
|
The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
|
@@ -71,13 +71,13 @@ class SegmentationPredictor(DetectionPredictor):
|
|
71
71
|
Construct a list of result objects from the predictions.
|
72
72
|
|
73
73
|
Args:
|
74
|
-
preds (
|
74
|
+
preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
|
75
75
|
img (torch.Tensor): The image after preprocessing.
|
76
|
-
orig_imgs (
|
77
|
-
protos (
|
76
|
+
orig_imgs (list[np.ndarray]): List of original images before preprocessing.
|
77
|
+
protos (list[torch.Tensor]): List of prototype masks.
|
78
78
|
|
79
79
|
Returns:
|
80
|
-
(
|
80
|
+
(list[Results]): List of result objects containing the original images, image paths, class names,
|
81
81
|
bounding boxes, and masks.
|
82
82
|
"""
|
83
83
|
return [
|
@@ -19,7 +19,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|
19
19
|
functionality including model initialization, validation, and visualization.
|
20
20
|
|
21
21
|
Attributes:
|
22
|
-
loss_names (
|
22
|
+
loss_names (tuple[str]): Names of the loss components used during training.
|
23
23
|
|
24
24
|
Examples:
|
25
25
|
>>> from ultralytics.models.yolo.segment import SegmentationTrainer
|
@@ -32,24 +32,16 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|
32
32
|
"""
|
33
33
|
Initialize a SegmentationTrainer object.
|
34
34
|
|
35
|
-
This initializes a trainer for segmentation tasks, extending the detection trainer with segmentation-specific
|
36
|
-
functionality. It sets the task to 'segment' and prepares the trainer for training segmentation models.
|
37
|
-
|
38
35
|
Args:
|
39
36
|
cfg (dict): Configuration dictionary with default training settings.
|
40
37
|
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
|
41
38
|
_callbacks (list, optional): List of callback functions to be executed during training.
|
42
|
-
|
43
|
-
Examples:
|
44
|
-
>>> from ultralytics.models.yolo.segment import SegmentationTrainer
|
45
|
-
>>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
|
46
|
-
>>> trainer = SegmentationTrainer(overrides=args)
|
47
|
-
>>> trainer.train()
|
48
39
|
"""
|
49
40
|
if overrides is None:
|
50
41
|
overrides = {}
|
51
42
|
overrides["task"] = "segment"
|
52
43
|
super().__init__(cfg, overrides, _callbacks)
|
44
|
+
self.dynamic_tensors = ["batch_idx", "cls", "bboxes", "masks"]
|
53
45
|
|
54
46
|
def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
|
55
47
|
"""
|