ultralytics 8.3.88__py3-none-any.whl → 8.3.90__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +125 -39
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +34 -33
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +33 -47
- ultralytics/engine/exporter.py +19 -17
- ultralytics/engine/model.py +69 -90
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +31 -38
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +21 -26
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +23 -17
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +29 -24
- ultralytics/models/nas/predict.py +14 -11
- ultralytics/models/nas/val.py +11 -13
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +21 -21
- ultralytics/models/rtdetr/train.py +25 -24
- ultralytics/models/rtdetr/val.py +47 -14
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +30 -12
- ultralytics/models/yolo/classify/train.py +83 -19
- ultralytics/models/yolo/classify/val.py +45 -23
- ultralytics/models/yolo/detect/predict.py +29 -19
- ultralytics/models/yolo/detect/train.py +90 -23
- ultralytics/models/yolo/detect/val.py +150 -29
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +18 -13
- ultralytics/models/yolo/obb/train.py +12 -8
- ultralytics/models/yolo/obb/val.py +35 -22
- ultralytics/models/yolo/pose/predict.py +28 -15
- ultralytics/models/yolo/pose/train.py +21 -8
- ultralytics/models/yolo/pose/val.py +51 -31
- ultralytics/models/yolo/segment/predict.py +27 -16
- ultralytics/models/yolo/segment/train.py +11 -8
- ultralytics/models/yolo/segment/val.py +110 -29
- ultralytics/models/yolo/world/train.py +43 -16
- ultralytics/models/yolo/world/train_world.py +61 -36
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +12 -12
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +226 -79
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +37 -35
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +35 -22
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +139 -68
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +37 -56
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +117 -52
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +65 -61
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +72 -59
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +202 -64
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +13 -25
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
- ultralytics-8.3.90.dist-info/RECORD +250 -0
- ultralytics-8.3.88.dist-info/RECORD +0 -250
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/detect/val.py CHANGED
@@ -18,18 +18,40 @@ class DetectionValidator(BaseValidator):
     """
     A class extending the BaseValidator class for validation based on a detection model.
 
-
-
-
-
-
-
-
-
+    This class implements validation functionality specific to object detection tasks, including metrics calculation,
+    prediction processing, and visualization of results.
+
+    Attributes:
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
+        is_coco (bool): Whether the dataset is COCO.
+        is_lvis (bool): Whether the dataset is LVIS.
+        class_map (List): Mapping from model class indices to dataset class indices.
+        metrics (DetMetrics): Object detection metrics calculator.
+        iouv (torch.Tensor): IoU thresholds for mAP calculation.
+        niou (int): Number of IoU thresholds.
+        lb (List): List for storing ground truth labels for hybrid saving.
+        jdict (List): List for storing JSON detection results.
+        stats (Dict): Dictionary for storing statistics during validation.
+
+    Examples:
+        >>> from ultralytics.models.yolo.detect import DetectionValidator
+        >>> args = dict(model="yolo11n.pt", data="coco8.yaml")
+        >>> validator = DetectionValidator(args=args)
+        >>> validator()
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
+        """
+        Initialize detection validator with necessary variables and settings.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
+            save_dir (Path, optional): Directory to save results.
+            pbar (Any, optional): Progress bar for displaying progress.
+            args (Dict, optional): Arguments for the validator.
+            _callbacks (List, optional): List of callback functions.
+        """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.nt_per_class = None
         self.nt_per_image = None
@@ -48,7 +70,15 @@ class DetectionValidator(BaseValidator):
         )
 
     def preprocess(self, batch):
-        """
+        """
+        Preprocess batch of images for YOLO validation.
+
+        Args:
+            batch (Dict): Batch containing images and annotations.
+
+        Returns:
+            (Dict): Preprocessed batch.
+        """
         batch["img"] = batch["img"].to(self.device, non_blocking=True)
         batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
         for k in ["batch_idx", "cls", "bboxes"]:
@@ -66,7 +96,12 @@ class DetectionValidator(BaseValidator):
         return batch
 
     def init_metrics(self, model):
-        """
+        """
+        Initialize evaluation metrics for YOLO detection validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         val = self.data.get(self.args.split, "")  # validation path
         self.is_coco = (
             isinstance(val, str)
@@ -91,7 +126,15 @@ class DetectionValidator(BaseValidator):
         return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
 
     def postprocess(self, preds):
-        """
+        """
+        Apply Non-maximum suppression to prediction outputs.
+
+        Args:
+            preds (torch.Tensor): Raw predictions from the model.
+
+        Returns:
+            (List[torch.Tensor]): Processed predictions after NMS.
+        """
         return ops.non_max_suppression(
             preds,
             self.args.conf,
@@ -106,7 +149,16 @@ class DetectionValidator(BaseValidator):
         )
 
     def _prepare_batch(self, si, batch):
-        """
+        """
+        Prepare a batch of images and annotations for validation.
+
+        Args:
+            si (int): Batch index.
+            batch (Dict): Batch data containing images and annotations.
+
+        Returns:
+            (Dict): Prepared batch with processed annotations.
+        """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
@@ -119,7 +171,16 @@ class DetectionValidator(BaseValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
-        """
+        """
+        Prepare predictions for evaluation against ground truth.
+
+        Args:
+            pred (torch.Tensor): Model predictions.
+            pbatch (Dict): Prepared batch information.
+
+        Returns:
+            (torch.Tensor): Prepared predictions in native space.
+        """
         predn = pred.clone()
         ops.scale_boxes(
             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
@@ -127,7 +188,13 @@ class DetectionValidator(BaseValidator):
         return predn
 
     def update_metrics(self, preds, batch):
-        """
+        """
+        Update metrics with new predictions and ground truth.
+
+        Args:
+            preds (List[torch.Tensor]): List of predictions from the model.
+            batch (Dict): Batch data containing ground truth.
+        """
         for si, pred in enumerate(preds):
             self.seen += 1
             npr = len(pred)
@@ -176,12 +243,23 @@ class DetectionValidator(BaseValidator):
         )
 
     def finalize_metrics(self, *args, **kwargs):
-        """
+        """
+        Set final values for metrics speed and confusion matrix.
+
+        Args:
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
+        """
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
 
     def get_stats(self):
-        """
+        """
+        Calculate and return metrics statistics.
+
+        Returns:
+            (Dict): Dictionary containing metrics results.
+        """
         stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
         self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
         self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
@@ -191,7 +269,7 @@ class DetectionValidator(BaseValidator):
         return self.metrics.results_dict
 
     def print_results(self):
-        """
+        """Print training/validation set metrics per class."""
         pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
         LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
         if self.nt_per_class.sum() == 0:
@@ -223,10 +301,6 @@ class DetectionValidator(BaseValidator):
 
         Returns:
             (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
-
-        Note:
-            The function does not return any value directly usable for metrics calculation. Instead, it provides an
-            intermediate representation used for evaluating predictions against ground truth.
         """
         iou = box_iou(gt_bboxes, detections[:, :4])
         return self.match_predictions(detections[:, 5], gt_cls, iou)
@@ -238,17 +312,35 @@ class DetectionValidator(BaseValidator):
         Args:
             img_path (str): Path to the folder containing images.
             mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-            batch (int, optional): Size of batches, this is for `rect`.
+            batch (int, optional): Size of batches, this is for `rect`.
+
+        Returns:
+            (Dataset): YOLO dataset.
         """
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
 
     def get_dataloader(self, dataset_path, batch_size):
-        """
+        """
+        Construct and return dataloader.
+
+        Args:
+            dataset_path (str): Path to the dataset.
+            batch_size (int): Size of each batch.
+
+        Returns:
+            (torch.utils.data.DataLoader): Dataloader for validation.
+        """
         dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
         return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader
 
     def plot_val_samples(self, batch, ni):
-        """
+        """
+        Plot validation image samples.
+
+        Args:
+            batch (Dict): Batch containing images and annotations.
+            ni (int): Batch index.
+        """
         plot_images(
             batch["img"],
             batch["batch_idx"],
@@ -261,7 +353,14 @@ class DetectionValidator(BaseValidator):
         )
 
     def plot_predictions(self, batch, preds, ni):
-        """
+        """
+        Plot predicted bounding boxes on input images and save the result.
+
+        Args:
+            batch (Dict): Batch containing images and annotations.
+            preds (List[torch.Tensor]): List of predictions from the model.
+            ni (int): Batch index.
+        """
         plot_images(
             batch["img"],
             *output_to_target(preds, max_det=self.args.max_det),
@@ -272,7 +371,15 @@ class DetectionValidator(BaseValidator):
         )  # pred
 
     def save_one_txt(self, predn, save_conf, shape, file):
-        """
+        """
+        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            save_conf (bool): Whether to save confidence scores.
+            shape (tuple): Shape of the original image.
+            file (Path): File path to save the detections.
+        """
         from ultralytics.engine.results import Results
 
         Results(
@@ -283,7 +390,13 @@ class DetectionValidator(BaseValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn, filename):
-        """
+        """
+        Serialize YOLO predictions to COCO json format.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            filename (str): Image filename.
+        """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         box = ops.xyxy2xywh(predn[:, :4])  # xywh
@@ -299,7 +412,15 @@ class DetectionValidator(BaseValidator):
         )
 
     def eval_json(self, stats):
-        """
+        """
+        Evaluate YOLO output in JSON format and return performance statistics.
+
+        Args:
+            stats (Dict): Current statistics dictionary.
+
+        Returns:
+            (Dict): Updated statistics dictionary with COCO/LVIS evaluation results.
+        """
         if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
             pred_json = self.save_dir / "predictions.json"  # predictions
            anno_json = (
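The expanded `DetectionValidator` docstring documents standalone validation for the first time. A minimal sketch of that usage, assuming `yolo11n.pt` and `coco8.yaml` can be fetched by ultralytics on first use:

```python
from ultralytics.models.yolo.detect import DetectionValidator

# Build a validator from plain keyword arguments, as the new Examples block shows
args = dict(model="yolo11n.pt", data="coco8.yaml")
validator = DetectionValidator(args=args)
validator()  # preprocess -> postprocess (NMS) -> update_metrics -> print_results
```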
ultralytics/models/yolo/model.py CHANGED
@@ -93,7 +93,7 @@ class YOLOWorld(Model):
 
     def set_classes(self, classes):
         """
-        Set
+        Set the model's class names for detection.
 
         Args:
             classes (List(str)): A list of categories i.e. ["person"].
@@ -106,6 +106,5 @@ class YOLOWorld(Model):
         self.model.names = classes
 
         # Reset method class names
-        # self.predictor = None  # reset predictor otherwise old names remain
         if self.predictor:
             self.predictor.model.names = classes
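The `set_classes` change is docstring-only plus removal of a stale commented-out line; runtime behavior is unchanged. Its documented use, sketched with `yolov8s-world.pt` as an assumed example checkpoint:

```python
from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")  # assumed available locally or downloadable
model.set_classes(["person", "bus"])   # updates model.names and, if a predictor exists, predictor.model.names
results = model.predict("https://ultralytics.com/images/bus.jpg")
```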
ultralytics/models/yolo/obb/predict.py CHANGED
@@ -11,29 +11,34 @@ class OBBPredictor(DetectionPredictor):
     """
     A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
 
-
-
-
-
-
-
-
-
-
+    This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
+    bounding boxes.
+
+    Attributes:
+        args (namespace): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO OBB model.
+
+    Examples:
+        >>> from ultralytics.utils import ASSETS
+        >>> from ultralytics.models.yolo.obb import OBBPredictor
+        >>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
+        >>> predictor = OBBPredictor(overrides=args)
+        >>> predictor.predict_cli()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """Initialize OBBPredictor with optional model and data configuration overrides."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "obb"
 
     def construct_result(self, pred, img, orig_img, img_path):
         """
-
+        Construct the result object from the prediction.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles
-
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 6) where
+                the last dimension contains [x, y, w, h, confidence, class_id, angle].
+            img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
 
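The new `OBBPredictor` docstring carries a runnable example; restated as a script, assuming the `yolo11n-obb.pt` checkpoint resolves:

```python
from ultralytics.models.yolo.obb import OBBPredictor
from ultralytics.utils import ASSETS

args = dict(model="yolo11n-obb.pt", source=ASSETS)
predictor = OBBPredictor(overrides=args)
predictor.predict_cli()  # each result carries rotated boxes assembled by construct_result()
```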
ultralytics/models/yolo/obb/train.py CHANGED
@@ -11,14 +11,18 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     """
     A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
 
-
-
-
-
-
-
-
-
+    Attributes:
+        loss_names (Tuple[str]): Names of the loss components used during training.
+
+    Methods:
+        get_model: Return OBBModel initialized with specified config and weights.
+        get_validator: Return an instance of OBBValidator for validation of YOLO model.
+
+    Examples:
+        >>> from ultralytics.models.yolo.obb import OBBTrainer
+        >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
+        >>> trainer = OBBTrainer(overrides=args)
+        >>> trainer.train()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
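Likewise for `OBBTrainer`, whose new Examples block maps directly onto a short training script (assumes the `dota8.yaml` toy dataset is reachable):

```python
from ultralytics.models.yolo.obb import OBBTrainer

args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
trainer = OBBTrainer(overrides=args)
trainer.train()
```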
ultralytics/models/yolo/obb/val.py CHANGED
@@ -14,14 +14,29 @@ class OBBValidator(DetectionValidator):
     """
     A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
 
-
-
-
-
-    args
-
-
-
+    This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
+    satellite imagery where objects can appear at various orientations.
+
+    Attributes:
+        args (Dict): Configuration arguments for the validator.
+        metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
+        is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.
+
+    Methods:
+        init_metrics: Initialize evaluation metrics for YOLO.
+        _process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.
+        _prepare_batch: Prepare batch data for OBB validation.
+        _prepare_pred: Prepare predictions with scaled and padded bounding boxes.
+        plot_predictions: Plot predicted bounding boxes on input images.
+        pred_to_json: Serialize YOLO predictions to COCO json format.
+        save_one_txt: Save YOLO detections to a txt file in normalized coordinates.
+        eval_json: Evaluate YOLO output in JSON format and return performance statistics.
+
+    Examples:
+        >>> from ultralytics.models.yolo.obb import OBBValidator
+        >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
+        >>> validator = OBBValidator(args=args)
+        >>> validator(model=args["model"])
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
@@ -34,7 +49,7 @@ class OBBValidator(DetectionValidator):
         """Initialize evaluation metrics for YOLO."""
         super().init_metrics(model)
         val = self.data.get(self.args.split, "")  # validation path
-        self.is_dota = isinstance(val, str) and "DOTA" in val  # is
+        self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
 
     def _process_batch(self, detections, gt_bboxes, gt_cls):
         """
@@ -51,13 +66,11 @@ class OBBValidator(DetectionValidator):
             (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
                 Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.
 
-
-
-
-
-
-            correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
-            ```
+        Examples:
+            >>> detections = torch.rand(100, 7)  # 100 sample detections
+            >>> gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
+            >>> gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
+            >>> correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
 
         Note:
             This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
@@ -66,7 +79,7 @@ class OBBValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
     def _prepare_batch(self, si, batch):
-        """
+        """Prepare batch data for OBB validation with proper scaling and formatting."""
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
@@ -79,7 +92,7 @@ class OBBValidator(DetectionValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
-        """
+        """Prepare predictions by scaling bounding boxes to original image dimensions."""
         predn = pred.clone()
         ops.scale_boxes(
             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
@@ -87,7 +100,7 @@ class OBBValidator(DetectionValidator):
         return predn
 
     def plot_predictions(self, batch, preds, ni):
-        """
+        """Plot predicted bounding boxes on input images and save the result."""
         plot_images(
             batch["img"],
             *output_to_rotated_target(preds, max_det=self.args.max_det),
@@ -98,7 +111,7 @@ class OBBValidator(DetectionValidator):
         )  # pred
 
     def pred_to_json(self, predn, filename):
-        """
+        """Convert YOLO predictions to COCO JSON format with rotated bounding box information."""
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
@@ -115,7 +128,7 @@ class OBBValidator(DetectionValidator):
         )
 
     def save_one_txt(self, predn, save_conf, shape, file):
-        """Save YOLO detections to a txt file in normalized coordinates
+        """Save YOLO detections to a txt file in normalized coordinates using the Results class."""
         import numpy as np
 
         from ultralytics.engine.results import Results
@@ -131,7 +144,7 @@ class OBBValidator(DetectionValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def eval_json(self, stats):
-        """
+        """Evaluate YOLO output in JSON format and save predictions in DOTA format."""
         if self.args.save_json and self.is_dota and len(self.jdict):
             import json
             import re
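The `OBBValidator` docstring example, restated as a sketch (same dataset assumption as the trainer above):

```python
from ultralytics.models.yolo.obb import OBBValidator

args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
validator = OBBValidator(args=args)
validator(model=args["model"])  # sets is_dota from the dataset path, then validates
```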
ultralytics/models/yolo/pose/predict.py CHANGED
@@ -8,19 +8,26 @@ class PosePredictor(DetectionPredictor):
     """
     A class extending the DetectionPredictor class for prediction based on a pose model.
 
-
-
-
-
-
-
-
-
-
+    This class specializes in pose estimation, handling keypoints detection alongside standard object detection
+    capabilities inherited from DetectionPredictor.
+
+    Attributes:
+        args (namespace): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
+
+    Methods:
+        construct_result: Constructs the result object from the prediction, including keypoints.
+
+    Examples:
+        >>> from ultralytics.utils import ASSETS
+        >>> from ultralytics.models.yolo.pose import PosePredictor
+        >>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
+        >>> predictor = PosePredictor(overrides=args)
+        >>> predictor.predict_cli()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """Initialize PosePredictor, set task to 'pose' and log a warning for using 'mps' as device."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "pose"
         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
@@ -31,19 +38,25 @@ class PosePredictor(DetectionPredictor):
 
     def construct_result(self, pred, img, orig_img, img_path):
         """
-
+        Construct the result object from the prediction, including keypoints.
+
+        This method extends the parent class implementation by extracting keypoint data from predictions
+        and adding them to the result object.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints
-
-
-
+            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
+                the number of detections, K is the number of keypoints, and D is the keypoint dimension.
+            img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
+            orig_img (np.ndarray): The original unprocessed image as a numpy array.
+            img_path (str): The path to the original image file.
 
         Returns:
             (Results): The result object containing the original image, image path, class names, bounding boxes, and keypoints.
         """
         result = super().construct_result(pred, img, orig_img, img_path)
+        # Extract keypoints from prediction and reshape according to model's keypoint shape
         pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
+        # Scale keypoints coordinates to match the original image dimensions
        pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
         result.update(keypoints=pred_kpts)
         return result
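The two comments added to `construct_result` describe the keypoint slicing. An illustration of that `(N, 6 + K*D)` layout with hypothetical stand-in shapes (K=17 COCO keypoints, D=3 for x, y, visibility):

```python
import torch

N, K, D = 2, 17, 3                     # hypothetical: 2 detections, COCO kpt_shape (17, 3)
pred = torch.rand(N, 6 + K * D)        # per row: box (4), conf, cls, then K*D keypoint values
pred_kpts = pred[:, 6:].view(N, K, D)  # same reshape construct_result() applies via kpt_shape
print(pred_kpts.shape)                 # torch.Size([2, 17, 3])
```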
ultralytics/models/yolo/pose/train.py CHANGED
@@ -10,16 +10,29 @@ from ultralytics.utils.plotting import plot_images, plot_results
 
 class PoseTrainer(yolo.detect.DetectionTrainer):
     """
-    A class extending the DetectionTrainer class for training
+    A class extending the DetectionTrainer class for training YOLO pose estimation models.
 
-
-
-        from ultralytics.models.yolo.pose import PoseTrainer
+    This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
+    of pose keypoints alongside bounding boxes.
 
-
-
-
-
+    Attributes:
+        args (Dict): Configuration arguments for training.
+        model (PoseModel): The pose estimation model being trained.
+        data (Dict): Dataset configuration including keypoint shape information.
+        loss_names (Tuple[str]): Names of the loss components used in training.
+
+    Methods:
+        get_model: Retrieves a pose estimation model with specified configuration.
+        set_model_attributes: Sets keypoints shape attribute on the model.
+        get_validator: Creates a validator instance for model evaluation.
+        plot_training_samples: Visualizes training samples with keypoints.
+        plot_metrics: Generates and saves training/validation metric plots.
+
+    Examples:
+        >>> from ultralytics.models.yolo.pose import PoseTrainer
+        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
+        >>> trainer = PoseTrainer(overrides=args)
+        >>> trainer.train()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
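And the matching `PoseTrainer` usage from its new docstring, runnable as-is provided `coco8-pose.yaml` is available to ultralytics:

```python
from ultralytics.models.yolo.pose import PoseTrainer

args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
trainer = PoseTrainer(overrides=args)
trainer.train()
```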
|