dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
- dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
- tests/conftest.py +5 -8
- tests/test_cli.py +1 -8
- tests/test_python.py +1 -2
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +34 -49
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +244 -323
- ultralytics/data/base.py +12 -22
- ultralytics/data/build.py +47 -40
- ultralytics/data/converter.py +32 -42
- ultralytics/data/dataset.py +43 -71
- ultralytics/data/loaders.py +22 -34
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +27 -36
- ultralytics/engine/exporter.py +49 -116
- ultralytics/engine/model.py +144 -180
- ultralytics/engine/predictor.py +18 -29
- ultralytics/engine/results.py +165 -231
- ultralytics/engine/trainer.py +11 -19
- ultralytics/engine/tuner.py +13 -23
- ultralytics/engine/validator.py +6 -10
- ultralytics/hub/__init__.py +7 -12
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +3 -6
- ultralytics/models/fastsam/model.py +6 -8
- ultralytics/models/fastsam/predict.py +5 -10
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +2 -4
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -18
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +13 -20
- ultralytics/models/sam/amg.py +12 -18
- ultralytics/models/sam/build.py +6 -9
- ultralytics/models/sam/model.py +16 -23
- ultralytics/models/sam/modules/blocks.py +62 -84
- ultralytics/models/sam/modules/decoders.py +17 -24
- ultralytics/models/sam/modules/encoders.py +40 -56
- ultralytics/models/sam/modules/memory_attention.py +10 -16
- ultralytics/models/sam/modules/sam.py +41 -47
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +17 -27
- ultralytics/models/sam/modules/utils.py +31 -42
- ultralytics/models/sam/predict.py +172 -209
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/classify/predict.py +8 -11
- ultralytics/models/yolo/classify/train.py +8 -16
- ultralytics/models/yolo/classify/val.py +13 -20
- ultralytics/models/yolo/detect/predict.py +4 -8
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +38 -48
- ultralytics/models/yolo/model.py +35 -47
- ultralytics/models/yolo/obb/predict.py +5 -8
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +20 -28
- ultralytics/models/yolo/pose/predict.py +5 -8
- ultralytics/models/yolo/pose/train.py +4 -8
- ultralytics/models/yolo/pose/val.py +31 -39
- ultralytics/models/yolo/segment/predict.py +9 -14
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +16 -26
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -16
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/autobackend.py +10 -18
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +99 -185
- ultralytics/nn/modules/conv.py +45 -90
- ultralytics/nn/modules/head.py +44 -98
- ultralytics/nn/modules/transformer.py +44 -76
- ultralytics/nn/modules/utils.py +14 -19
- ultralytics/nn/tasks.py +86 -146
- ultralytics/nn/text_model.py +25 -40
- ultralytics/solutions/ai_gym.py +10 -16
- ultralytics/solutions/analytics.py +7 -10
- ultralytics/solutions/config.py +4 -5
- ultralytics/solutions/distance_calculation.py +9 -12
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +8 -12
- ultralytics/solutions/object_cropper.py +5 -8
- ultralytics/solutions/parking_management.py +12 -14
- ultralytics/solutions/queue_management.py +4 -6
- ultralytics/solutions/region_counter.py +7 -10
- ultralytics/solutions/security_alarm.py +14 -19
- ultralytics/solutions/similarity_search.py +7 -12
- ultralytics/solutions/solutions.py +31 -53
- ultralytics/solutions/speed_estimation.py +6 -9
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/basetrack.py +2 -4
- ultralytics/trackers/bot_sort.py +6 -11
- ultralytics/trackers/byte_tracker.py +10 -15
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +6 -12
- ultralytics/trackers/utils/kalman_filter.py +35 -43
- ultralytics/trackers/utils/matching.py +6 -10
- ultralytics/utils/__init__.py +61 -100
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +11 -13
- ultralytics/utils/benchmarks.py +25 -35
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +2 -4
- ultralytics/utils/callbacks/comet.py +30 -44
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +4 -6
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +4 -6
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +29 -56
- ultralytics/utils/cpu.py +1 -2
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +17 -27
- ultralytics/utils/errors.py +6 -8
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -239
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +11 -17
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +10 -15
- ultralytics/utils/git.py +5 -7
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +11 -15
- ultralytics/utils/loss.py +8 -14
- ultralytics/utils/metrics.py +98 -138
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +47 -74
- ultralytics/utils/patches.py +11 -18
- ultralytics/utils/plotting.py +29 -42
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +45 -73
- ultralytics/utils/tqdm.py +6 -8
- ultralytics/utils/triton.py +9 -12
- ultralytics/utils/tuner.py +1 -2
- dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0

ultralytics/models/yolo/detect/val.py CHANGED

```diff
@@ -19,8 +19,7 @@ from ultralytics.utils.plotting import plot_images


 class DetectionValidator(BaseValidator):
-    """
-    A class extending the BaseValidator class for validation based on a detection model.
+    """A class extending the BaseValidator class for validation based on a detection model.

     This class implements validation functionality specific to object detection tasks, including metrics calculation,
     prediction processing, and visualization of results.
@@ -44,8 +43,7 @@ class DetectionValidator(BaseValidator):
     """

     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize detection validator with necessary variables and settings.
+        """Initialize detection validator with necessary variables and settings.

         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
@@ -63,8 +61,7 @@ class DetectionValidator(BaseValidator):
         self.metrics = DetMetrics()

     def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Preprocess batch of images for YOLO validation.
+        """Preprocess batch of images for YOLO validation.

         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -79,8 +76,7 @@ class DetectionValidator(BaseValidator):
         return batch

     def init_metrics(self, model: torch.nn.Module) -> None:
-        """
-        Initialize evaluation metrics for YOLO detection validation.
+        """Initialize evaluation metrics for YOLO detection validation.

         Args:
             model (torch.nn.Module): Model to validate.
@@ -107,15 +103,14 @@ class DetectionValidator(BaseValidator):
         return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")

     def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
-        """
-        Apply Non-maximum suppression to prediction outputs.
+        """Apply Non-maximum suppression to prediction outputs.

         Args:
             preds (torch.Tensor): Raw predictions from the model.

         Returns:
-            (list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
-                'bboxes', 'conf', 'cls', and 'extra' tensors.
+            (list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains 'bboxes', 'conf',
+                'cls', and 'extra' tensors.
         """
         outputs = nms.non_max_suppression(
             preds,
```
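
The reflowed `postprocess` docstring pins down the post-NMS contract: one dict of tensors per image, keyed `'bboxes'`, `'conf'`, `'cls'`, and `'extra'`. A minimal sketch of consuming that structure, using dummy data rather than anything from this release:

```python
import torch

# Dummy stand-in for DetectionValidator.postprocess output: one dict per image.
outputs = [{
    "bboxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]]),  # xyxy boxes
    "conf": torch.tensor([0.87]),                          # confidence scores
    "cls": torch.tensor([0.0]),                            # class indices
    "extra": torch.zeros((1, 0)),                          # task-specific extras (e.g., angle for OBB)
}]

for i, pred in enumerate(outputs):
    keep = pred["conf"] > 0.5  # confidence mask over detections
    for box, score, cls_id in zip(pred["bboxes"][keep], pred["conf"][keep], pred["cls"][keep]):
        print(f"image {i}: class={int(cls_id)} conf={float(score):.2f} xyxy={box.tolist()}")
```
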
```diff
@@ -131,8 +126,7 @@ class DetectionValidator(BaseValidator):
         return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]

     def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Prepare a batch of images and annotations for validation.
+        """Prepare a batch of images and annotations for validation.

         Args:
             si (int): Batch index.
@@ -159,8 +153,7 @@ class DetectionValidator(BaseValidator):
         }

     def _prepare_pred(self, pred: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-        """
-        Prepare predictions for evaluation against ground truth.
+        """Prepare predictions for evaluation against ground truth.

         Args:
             pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
@@ -173,8 +166,7 @@ class DetectionValidator(BaseValidator):
         return pred

     def update_metrics(self, preds: list[dict[str, torch.Tensor]], batch: dict[str, Any]) -> None:
-        """
-        Update metrics with new predictions and ground truth.
+        """Update metrics with new predictions and ground truth.

         Args:
             preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
@@ -236,15 +228,21 @@ class DetectionValidator(BaseValidator):
             for stats_dict in gathered_stats:
                 for key in merged_stats.keys():
                     merged_stats[key].extend(stats_dict[key])
+            gathered_jdict = [None] * dist.get_world_size()
+            dist.gather_object(self.jdict, gathered_jdict, dst=0)
+            self.jdict = []
+            for jdict in gathered_jdict:
+                self.jdict.extend(jdict)
             self.metrics.stats = merged_stats
             self.seen = len(self.dataloader.dataset)  # total image count from dataset
         elif RANK > 0:
             dist.gather_object(self.metrics.stats, None, dst=0)
+            dist.gather_object(self.jdict, None, dst=0)
+            self.jdict = []
         self.metrics.clear_stats()

     def get_stats(self) -> dict[str, Any]:
-        """
-        Calculate and return metrics statistics.
+        """Calculate and return metrics statistics.

         Returns:
             (dict[str, Any]): Dictionary containing metrics results.
```
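
Beyond the docstring reflow, this hunk carries the one functional change in the file: under distributed validation, each rank's COCO JSON buffer (`self.jdict`) is now gathered onto rank 0 and cleared on the other ranks, so predictions from every GPU end up in a single evaluation file. A standalone sketch of the same `gather_object` pattern, assuming `torch.distributed.init_process_group` has already run (the helper name is hypothetical):

```python
import torch.distributed as dist

def gather_lists_to_rank0(local_list: list, rank: int) -> list:
    """Collect per-rank lists onto rank 0, mirroring the jdict gather above."""
    if rank == 0:
        gathered = [None] * dist.get_world_size()
        dist.gather_object(local_list, gathered, dst=0)  # rank 0 receives every rank's list
        merged = []
        for part in gathered:
            merged.extend(part)
        return merged
    dist.gather_object(local_list, None, dst=0)  # non-destination ranks send, receive nothing
    return []  # other ranks drop their buffer, as the validator now does
```
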
```diff
@@ -274,15 +272,15 @@ class DetectionValidator(BaseValidator):
         )

     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
-        """
-        Return correct prediction matrix.
+        """Return correct prediction matrix.

         Args:
             preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
             batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.

         Returns:
-            (dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for
+            (dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for
+                10 IoU levels.
         """
         if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
             return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
@@ -290,8 +288,7 @@ class DetectionValidator(BaseValidator):
         return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}

     def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None) -> torch.utils.data.Dataset:
-        """
-        Build YOLO Dataset.
+        """Build YOLO Dataset.

         Args:
             img_path (str): Path to the folder containing images.
@@ -304,8 +301,7 @@ class DetectionValidator(BaseValidator):
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)

     def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
-        """
-        Construct and return dataloader.
+        """Construct and return dataloader.

         Args:
             dataset_path (str): Path to the dataset.
@@ -326,8 +322,7 @@ class DetectionValidator(BaseValidator):
         )

     def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
-        """
-        Plot validation image samples.
+        """Plot validation image samples.

         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -344,8 +339,7 @@ class DetectionValidator(BaseValidator):
     def plot_predictions(
         self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: int | None = None
     ) -> None:
-        """
-        Plot predicted bounding boxes on input images and save the result.
+        """Plot predicted bounding boxes on input images and save the result.

         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -371,8 +365,7 @@ class DetectionValidator(BaseValidator):
         )  # pred

     def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
-        """
-        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format.

         Args:
             predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
@@ -390,12 +383,11 @@ class DetectionValidator(BaseValidator):
         ).save_txt(file, save_conf=save_conf)

     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Serialize YOLO predictions to COCO json format.
+        """Serialize YOLO predictions to COCO json format.

         Args:
-            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
-                with bounding box coordinates, confidence scores, and class predictions.
+            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                bounding box coordinates, confidence scores, and class predictions.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

         Examples:
@@ -436,8 +428,7 @@ class DetectionValidator(BaseValidator):
         }

     def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
-        """
-        Evaluate YOLO output in JSON format and return performance statistics.
+        """Evaluate YOLO output in JSON format and return performance statistics.

         Args:
             stats (dict[str, Any]): Current statistics dictionary.
@@ -461,21 +452,20 @@ class DetectionValidator(BaseValidator):
         iou_types: str | list[str] = "bbox",
         suffix: str | list[str] = "Box",
     ) -> dict[str, Any]:
-        """
-        Evaluate COCO/LVIS metrics using faster-coco-eval library.
+        """Evaluate COCO/LVIS metrics using faster-coco-eval library.

-        Performs evaluation using the faster-coco-eval library to compute mAP metrics
-        for object detection. Updates the provided stats dictionary with computed metrics including mAP50,
-        mAP50-95, and LVIS-specific metrics if applicable.
+        Performs evaluation using the faster-coco-eval library to compute mAP metrics for object detection. Updates the
+        provided stats dictionary with computed metrics including mAP50, mAP50-95, and LVIS-specific metrics if
+        applicable.

         Args:
             stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
             pred_json (str | Path]): Path to JSON file containing predictions in COCO format.
             anno_json (str | Path]): Path to JSON file containing ground truth annotations in COCO format.
-            iou_types (str | list[str]]): IoU type(s) for evaluation. Can be single string or list of strings.
-                Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
-            suffix (str | list[str]]): Suffix to append to metric names in stats dictionary. Should correspond
-                to iou_types if multiple types provided. Defaults to "Box".
+            iou_types (str | list[str]]): IoU type(s) for evaluation. Can be single string or list of strings. Common
+                values include "bbox", "segm", "keypoints". Defaults to "bbox".
+            suffix (str | list[str]]): Suffix to append to metric names in stats dictionary. Should correspond to
+                iou_types if multiple types provided. Defaults to "Box".

         Returns:
             (dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
```
ultralytics/models/yolo/model.py CHANGED

```diff
@@ -24,8 +24,7 @@ from ultralytics.utils import ROOT, YAML


 class YOLO(Model):
-    """
-    YOLO (You Only Look Once) object detection model.
+    """YOLO (You Only Look Once) object detection model.

     This class provides a unified interface for YOLO models, automatically switching to specialized model types
     (YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
@@ -52,16 +51,15 @@ class YOLO(Model):
     """

     def __init__(self, model: str | Path = "yolo11n.pt", task: str | None = None, verbose: bool = False):
-        """
-        Initialize a YOLO model.
+        """Initialize a YOLO model.

-        This constructor initializes a YOLO model, automatically switching to specialized model types
-        (YOLOWorld or YOLOE) based on the model filename.
+        This constructor initializes a YOLO model, automatically switching to specialized model types (YOLOWorld or
+        YOLOE) based on the model filename.

         Args:
             model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
-            task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
-                Defaults to auto-detection based on model.
+            task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'. Defaults
+                to auto-detection based on model.
             verbose (bool): Display model info on load.

         Examples:
```
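
The filename-based switching described above means one entry point covers all three model families; a brief usage sketch (the weights names are stock Ultralytics checkpoints that download on first use):

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")        # plain YOLO detector
world = YOLO("yolov8s-world.pt")  # re-dispatches to YOLOWorld based on the filename
yoloe = YOLO("yoloe-11s-seg.pt")  # re-dispatches to YOLOE
print(type(model).__name__, type(world).__name__, type(yoloe).__name__)
```
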
```diff
@@ -126,12 +124,11 @@ class YOLO(Model):


 class YOLOWorld(Model):
-    """
-    YOLO-World object detection model.
+    """YOLO-World object detection model.

-    YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions
-    without requiring training on specific classes. It extends the YOLO architecture to support real-time
-    open-vocabulary detection.
+    YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions without
+    requiring training on specific classes. It extends the YOLO architecture to support real-time open-vocabulary
+    detection.

     Attributes:
         model: The loaded YOLO-World model instance.
@@ -152,11 +149,10 @@ class YOLOWorld(Model):
     """

     def __init__(self, model: str | Path = "yolov8s-world.pt", verbose: bool = False) -> None:
-        """
-        Initialize YOLOv8-World model with a pre-trained model file.
+        """Initialize YOLOv8-World model with a pre-trained model file.

-        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
-        COCO class names.
+        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default COCO
+        class names.

         Args:
             model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
@@ -181,8 +177,7 @@ class YOLOWorld(Model):
         }

     def set_classes(self, classes: list[str]) -> None:
-        """
-        Set the model's class names for detection.
+        """Set the model's class names for detection.

         Args:
             classes (list[str]): A list of categories i.e. ["person"].
```
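
A minimal open-vocabulary sketch for `set_classes`, following the usage pattern in the Ultralytics docs (the image path is hypothetical):

```python
from ultralytics import YOLO

model = YOLO("yolov8s-world.pt")
model.set_classes(["person", "bus"])  # restrict detection to this vocabulary
results = model.predict("bus.jpg")    # hypothetical local image
results[0].show()
```
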
```diff
@@ -200,11 +195,10 @@ class YOLOWorld(Model):


 class YOLOE(Model):
-    """
-    YOLOE object detection and segmentation model.
+    """YOLOE object detection and segmentation model.

-    YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with
-    improved performance and additional features like visual and text positional embeddings.
+    YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with improved
+    performance and additional features like visual and text positional embeddings.

     Attributes:
         model: The loaded YOLOE model instance.
@@ -235,8 +229,7 @@ class YOLOE(Model):
     """

     def __init__(self, model: str | Path = "yoloe-11s-seg.pt", task: str | None = None, verbose: bool = False) -> None:
-        """
-        Initialize YOLOE model with a pre-trained model file.
+        """Initialize YOLOE model with a pre-trained model file.

         Args:
             model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
@@ -269,11 +262,10 @@ class YOLOE(Model):
         return self.model.get_text_pe(texts)

     def get_visual_pe(self, img, visual):
-        """
-        Get visual positional embeddings for the given image and visual features.
+        """Get visual positional embeddings for the given image and visual features.

-        This method extracts positional embeddings from visual features based on the input image. It requires
-        that the model is an instance of YOLOEModel.
+        This method extracts positional embeddings from visual features based on the input image. It requires that the
+        model is an instance of YOLOEModel.

         Args:
             img (torch.Tensor): Input image tensor.
@@ -292,11 +284,10 @@ class YOLOE(Model):
         return self.model.get_visual_pe(img, visual)

     def set_vocab(self, vocab: list[str], names: list[str]) -> None:
-        """
-        Set vocabulary and class names for the YOLOE model.
+        """Set vocabulary and class names for the YOLOE model.

-        This method configures the vocabulary and class names used by the model for text processing and
-        classification tasks. The model must be an instance of YOLOEModel.
+        This method configures the vocabulary and class names used by the model for text processing and classification
+        tasks. The model must be an instance of YOLOEModel.

         Args:
             vocab (list[str]): Vocabulary list containing tokens or words used by the model for text processing.
@@ -318,8 +309,7 @@ class YOLOE(Model):
         return self.model.get_vocab(names)

     def set_classes(self, classes: list[str], embeddings: torch.Tensor | None = None) -> None:
-        """
-        Set the model's class names and embeddings for detection.
+        """Set the model's class names and embeddings for detection.

         Args:
             classes (list[str]): A list of categories i.e. ["person"].
@@ -344,8 +334,7 @@ class YOLOE(Model):
         refer_data: str | None = None,
         **kwargs,
     ):
-        """
-        Validate the model using text or visual prompts.
+        """Validate the model using text or visual prompts.

         Args:
             validator (callable, optional): A callable validator function. If None, a default validator is loaded.
@@ -373,19 +362,18 @@ class YOLOE(Model):
         predictor=yolo.yoloe.YOLOEVPDetectPredictor,
         **kwargs,
     ):
-        """
-        Run prediction on images, videos, directories, streams, etc.
+        """Run prediction on images, videos, directories, streams, etc.

         Args:
-            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
-                directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
-            stream (bool): Whether to stream the prediction results. If True, results are yielded as a
-                generator as they are computed.
-            visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include
-                'bboxes' and 'cls' keys when non-empty.
+            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths, directory
+                paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
+            stream (bool): Whether to stream the prediction results. If True, results are yielded as a generator as they
+                are computed.
+            visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include 'bboxes'
+                and 'cls' keys when non-empty.
             refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
-            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
-                loaded based on the task.
+            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically loaded
+                based on the task.
             **kwargs (Any): Additional keyword arguments passed to the predictor.

         Returns:
```
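
Putting the `visual_prompts` contract above into practice, a hedged sketch of prompt-based prediction (box coordinates and image path are illustrative only):

```python
import numpy as np
from ultralytics import YOLO

model = YOLO("yoloe-11s-seg.pt")
# visual_prompts must carry 'bboxes' and 'cls' when non-empty, per the docstring above.
visual_prompts = dict(
    bboxes=np.array([[221.5, 405.8, 344.9, 857.5]]),  # one exemplar box, xyxy
    cls=np.array([0]),                                # prompt class index for that box
)
results = model.predict("bus.jpg", visual_prompts=visual_prompts)  # hypothetical image
```
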
ultralytics/models/yolo/obb/predict.py CHANGED

```diff
@@ -8,8 +8,7 @@ from ultralytics.utils import DEFAULT_CFG, ops


 class OBBPredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
+    """A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.

     This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
     bounding boxes.
@@ -27,8 +26,7 @@ class OBBPredictor(DetectionPredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize OBBPredictor with optional model and data configuration overrides.
+        """Initialize OBBPredictor with optional model and data configuration overrides.

         Args:
             cfg (dict, optional): Default configuration for the predictor.
@@ -45,12 +43,11 @@ class OBBPredictor(DetectionPredictor):
         self.args.task = "obb"

     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct the result object from the prediction.
+        """Construct the result object from the prediction.

         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
-                the last dimension contains [x, y, w, h, confidence, class_id, angle].
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where the
+                last dimension contains [x, y, w, h, confidence, class_id, angle].
             img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
```
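
Since `construct_result` consumes (N, 7) rows of [x, y, w, h, confidence, class_id, angle], here is a small self-contained sketch that decodes one such row into its four corner points, assuming the angle is in radians:

```python
import numpy as np

def xywhr_to_corners(x, y, w, h, r):
    """Return the 4 corners of a rotated box centered at (x, y) with angle r (radians)."""
    cos_r, sin_r = np.cos(r), np.sin(r)
    dx, dy = w / 2, h / 2
    offsets = np.array([[-dx, -dy], [dx, -dy], [dx, dy], [-dx, dy]])  # axis-aligned corners
    rot = np.array([[cos_r, -sin_r], [sin_r, cos_r]])                 # 2D rotation matrix
    return offsets @ rot.T + np.array([x, y])                         # rotate, then translate

print(xywhr_to_corners(50.0, 50.0, 40.0, 20.0, np.pi / 6))
```
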
ultralytics/models/yolo/obb/train.py CHANGED

```diff
@@ -12,15 +12,14 @@ from ultralytics.utils import DEFAULT_CFG, RANK


 class OBBTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
+    """A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.

-    This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for
-    detecting objects at arbitrary angles rather than just axis-aligned rectangles.
+    This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for detecting
+    objects at arbitrary angles rather than just axis-aligned rectangles.

     Attributes:
-        loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
-            and dfl_loss.
+        loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss, and
+            dfl_loss.

     Methods:
         get_model: Return OBBModel initialized with specified config and weights.
@@ -34,14 +33,13 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
-        """
-        Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
+        """Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.

         Args:
-            cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
-                model configuration.
-            overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
-                will take precedence over those in cfg.
+            cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and model
+                configuration.
+            overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here will
+                take precedence over those in cfg.
             _callbacks (list[Any], optional): List of callback functions to be invoked during training.
         """
         if overrides is None:
@@ -52,8 +50,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     def get_model(
         self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
     ) -> OBBModel:
-        """
-        Return OBBModel initialized with specified config and weights.
+        """Return OBBModel initialized with specified config and weights.

         Args:
             cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
```
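
A short training sketch for `OBBTrainer`, mirroring the overrides-take-precedence behavior documented above (weights and dataset names are the stock Ultralytics samples):

```python
from ultralytics.models.yolo.obb import OBBTrainer

# Values in overrides win over the defaults in cfg, per the docstring above.
args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
trainer = OBBTrainer(overrides=args)
trainer.train()
```
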
ultralytics/models/yolo/obb/val.py CHANGED

```diff
@@ -15,8 +15,7 @@ from ultralytics.utils.nms import TorchNMS


 class OBBValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
+    """A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.

     This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
     satellite imagery where objects can appear at various orientations.
@@ -44,11 +43,10 @@ class OBBValidator(DetectionValidator):
     """

     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
+        """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.

-        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
-        It extends the DetectionValidator class and configures it specifically for the OBB task.
+        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models. It
+        extends the DetectionValidator class and configures it specifically for the OBB task.

         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
@@ -61,8 +59,7 @@ class OBBValidator(DetectionValidator):
         self.metrics = OBBMetrics()

     def init_metrics(self, model: torch.nn.Module) -> None:
-        """
-        Initialize evaluation metrics for YOLO obb validation.
+        """Initialize evaluation metrics for YOLO obb validation.

         Args:
             model (torch.nn.Module): Model to validate.
@@ -73,19 +70,18 @@ class OBBValidator(DetectionValidator):
         self.confusion_matrix.task = "obb"  # set confusion matrix task to 'obb'

     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
-        """
-        Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
+        """Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.

         Args:
             preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
                 class labels and bounding boxes.
-            batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
-                class labels and bounding boxes.
+            batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth class
+                labels and bounding boxes.

         Returns:
-            (dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
-                array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy
-                of predictions compared to the ground truth.
+            (dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy array
+                with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy of
+                predictions compared to the ground truth.

         Examples:
             >>> detections = torch.rand(100, 7)  # 100 sample detections
```
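
The (N, 10) 'tp' matrix marks each prediction as correct or not at the ten COCO IoU thresholds 0.50:0.05:0.95. A deliberately simplified sketch of that thresholding (single class, no one-to-one matching of predictions to ground truth, unlike the real `match_predictions`):

```python
import numpy as np

iou = np.array([[0.62, 0.10], [0.30, 0.91]])  # (num_gt, num_pred) pairwise IoU
thresholds = np.linspace(0.5, 0.95, 10)       # COCO-style IoU levels

# tp[j, k] is True if prediction j overlaps some ground truth at IoU >= thresholds[k].
tp = iou.max(axis=0)[:, None] >= thresholds[None, :]
print(tp.astype(int))  # shape (num_pred, 10)
```
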
```diff
@@ -99,7 +95,8 @@ class OBBValidator(DetectionValidator):
         return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}

     def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
-        """
+        """Postprocess OBB predictions.
+
         Args:
             preds (torch.Tensor): Raw predictions from the model.

@@ -112,8 +109,7 @@ class OBBValidator(DetectionValidator):
         return preds

     def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Prepare batch data for OBB validation with proper scaling and formatting.
+        """Prepare batch data for OBB validation with proper scaling and formatting.

         Args:
             si (int): Batch index to process.
@@ -146,8 +142,7 @@ class OBBValidator(DetectionValidator):
         }

     def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
-        """
-        Plot predicted bounding boxes on input images and save the result.
+        """Plot predicted bounding boxes on input images and save the result.

         Args:
             batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
@@ -166,12 +161,11 @@ class OBBValidator(DetectionValidator):
         super().plot_predictions(batch, preds, ni)  # plot bboxes

     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Convert YOLO predictions to COCO JSON format with rotated bounding box information.
+        """Convert YOLO predictions to COCO JSON format with rotated bounding box information.

         Args:
-            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
-                with bounding box coordinates, confidence scores, and class predictions.
+            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                bounding box coordinates, confidence scores, and class predictions.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

         Notes:
@@ -197,8 +191,7 @@ class OBBValidator(DetectionValidator):
         )

     def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
-        """
-        Save YOLO OBB detections to a text file in normalized coordinates.
+        """Save YOLO OBB detections to a text file in normalized coordinates.

         Args:
             predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
@@ -233,8 +226,7 @@ class OBBValidator(DetectionValidator):
         }

     def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
-        """
-        Evaluate YOLO output in JSON format and save predictions in DOTA format.
+        """Evaluate YOLO output in JSON format and save predictions in DOTA format.

         Args:
             stats (dict[str, Any]): Performance statistics dictionary.
```
ultralytics/models/yolo/pose/predict.py CHANGED

```diff
@@ -5,8 +5,7 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, ops


 class PosePredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on a pose model.
+    """A class extending the DetectionPredictor class for prediction based on a pose model.

     This class specializes in pose estimation, handling keypoints detection alongside standard object detection
     capabilities inherited from DetectionPredictor.
@@ -27,11 +26,10 @@ class PosePredictor(DetectionPredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize PosePredictor for pose estimation tasks.
+        """Initialize PosePredictor for pose estimation tasks.

-        Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific
-        warnings for Apple MPS.
+        Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific warnings
+        for Apple MPS.

         Args:
             cfg (Any): Configuration for the predictor.
@@ -54,8 +52,7 @@ class PosePredictor(DetectionPredictor):
         )

     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct the result object from the prediction, including keypoints.
+        """Construct the result object from the prediction, including keypoints.

         Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
         result object.
```