ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +527 -67
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +44 -37
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +84 -56
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.28.dist-info/METADATA +0 -373
- ultralytics-8.1.28.dist-info/RECORD +0 -197
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0

ultralytics/models/utils/__init__.py CHANGED
@@ -1 +1 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

ultralytics/models/utils/loss.py CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import torch
 import torch.nn as nn
@@ -6,6 +6,7 @@ import torch.nn.functional as F

 from ultralytics.utils.loss import FocalLoss, VarifocalLoss
 from ultralytics.utils.metrics import bbox_iou
+
 from .ops import HungarianMatcher


@@ -33,15 +34,19 @@ class DETRLoss(nn.Module):
         self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
     ):
         """
-        DETR loss function.
+        Initialize DETR loss function with customizable components and gains.
+
+        Uses default loss_gain if not provided. Initializes HungarianMatcher with
+        preset cost gains. Supports auxiliary losses and various loss types.

         Args:
-            nc (int):
-            loss_gain (dict):
-            aux_loss (bool):
-
-
-
+            nc (int): Number of classes.
+            loss_gain (dict): Coefficients for different loss components.
+            aux_loss (bool): Use auxiliary losses from each decoder layer.
+            use_fl (bool): Use FocalLoss.
+            use_vfl (bool): Use VarifocalLoss.
+            use_uni_match (bool): Use fixed layer for auxiliary branch label assignment.
+            uni_match_ind (int): Index of fixed layer for uni_match.
         """
         super().__init__()

@@ -81,9 +86,7 @@ class DETRLoss(nn.Module):
         return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}

     def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
-        """
-        boxes.
-        """
+        """Computes bounding box and GIoU losses for predicted and ground truth bounding boxes."""
         # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
         name_bbox = f"loss_bbox{postfix}"
         name_giou = f"loss_giou{postfix}"
@@ -240,23 +243,32 @@ class DETRLoss(nn.Module):
         if len(gt_bboxes):
             gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)

-
-
-
-
-
-        return loss
+        return {
+            **self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),
+            **self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),
+            # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
+        }

     def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
         """
+        Calculate loss for predicted bounding boxes and scores.
+
         Args:
-            pred_bboxes (torch.Tensor): [l, b, query, 4]
-            pred_scores (torch.Tensor): [l, b, query, num_classes]
-            batch (dict):
-
-
-                gt_groups (List
-            postfix (str):
+            pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
+            pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
+            batch (dict): Batch information containing:
+                cls (torch.Tensor): Ground truth classes, shape [num_gts].
+                bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
+                gt_groups (List[int]): Number of ground truths for each image in the batch.
+            postfix (str): Postfix for loss names.
+            **kwargs (Any): Additional arguments, may include 'match_indices'.
+
+        Returns:
+            (dict): Computed losses, including main and auxiliary (if enabled).
+
+        Note:
+            Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
+            self.aux_loss is True.
         """
         self.device = pred_bboxes.device
         match_indices = kwargs.get("match_indices", None)
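
To make the reconstructed `forward` docstring above concrete, here is a minimal driver sketch for `DETRLoss`; every shape and value below is an illustrative assumption, not anything taken from the package's own tests.

```python
# Hedged sketch: driving DETRLoss.forward with the batch layout the docstring
# describes (cls, bboxes, gt_groups). Shapes are made up: l decoder layers,
# b images, q queries, nc classes.
import torch

from ultralytics.models.utils.loss import DETRLoss

l, b, q, nc = 3, 2, 100, 80
pred_bboxes = torch.rand(l, b, q, 4)  # normalized xywh boxes per decoder layer
pred_scores = torch.rand(l, b, q, nc)  # class logits per decoder layer
batch = {
    "cls": torch.tensor([3, 17, 42]),  # ground-truth classes, shape [num_gts]
    "bboxes": torch.rand(3, 4),  # ground-truth xywh boxes, shape [num_gts, 4]
    "gt_groups": [2, 1],  # ground truths per image; sums to num_gts
}
criterion = DETRLoss(nc=nc, aux_loss=True)
losses = criterion(pred_bboxes, pred_scores, batch)  # main + *_aux entries
print({k: round(float(v), 4) for k, v in losses.items()})
```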

ultralytics/models/utils/ops.py CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import torch
 import torch.nn as nn
@@ -32,9 +32,7 @@ class HungarianMatcher(nn.Module):
     """

     def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
-        """Initializes HungarianMatcher
-        gamma factors.
-        """
+        """Initializes a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
         super().__init__()
         if cost_gain is None:
             cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
@@ -70,7 +68,6 @@ class HungarianMatcher(nn.Module):
             For each batch element, it holds:
                 len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
         """
-
         bs, nq, nc = pred_scores.shape

         if sum(gt_groups) == 0:
@@ -133,7 +130,7 @@ class HungarianMatcher(nn.Module):
         # sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
         # tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
         #
-        # with torch.
+        # with torch.amp.autocast("cuda", enabled=False):
         #     # binary cross entropy cost
         #     pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
         #     neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
@@ -175,7 +172,6 @@ def get_cdn_group(
         bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
         is less than or equal to 0, the function returns None for all elements in the tuple.
     """
-
     if (not training) or num_dn <= 0:
         return None, None, None, None
     gt_groups = batch["gt_groups"]
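
For orientation, a small sketch of calling the matcher directly with the default cost gains shown in `__init__`; the tensors are made-up stand-ins, and the output is one `(pred_idx, gt_idx)` pair per image.

```python
# Hedged sketch of HungarianMatcher assignment on toy tensors (assumed shapes).
import torch

from ultralytics.models.utils.ops import HungarianMatcher

matcher = HungarianMatcher(cost_gain={"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1})
bs, nq, nc = 2, 10, 80
pred_bboxes = torch.rand(bs, nq, 4)  # normalized xywh per query
pred_scores = torch.rand(bs, nq, nc)  # class logits per query
gt_cls = torch.tensor([1, 5, 5])  # 3 ground-truth boxes in total
gt_bboxes = torch.rand(3, 4)
gt_groups = [2, 1]  # boxes per image; sums to len(gt_cls)
indices = matcher(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups)
for i, (pred_idx, gt_idx) in enumerate(indices):
    print(f"image {i}: queries {pred_idx.tolist()} -> gts {gt_idx.tolist()}")
```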

ultralytics/models/yolo/__init__.py CHANGED
@@ -1,7 +1,7 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

-from ultralytics.models.yolo import classify, detect, obb, pose, segment
+from ultralytics.models.yolo import classify, detect, obb, pose, segment, world

 from .model import YOLO, YOLOWorld

-__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld"
+__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"

ultralytics/models/yolo/classify/predict.py CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import cv2
 import torch
@@ -21,7 +21,7 @@ class ClassificationPredictor(BasePredictor):
         from ultralytics.utils import ASSETS
         from ultralytics.models.yolo.classify import ClassificationPredictor

-        args = dict(model=
+        args = dict(model="yolov8n-cls.pt", source=ASSETS)
         predictor = ClassificationPredictor(overrides=args)
         predictor.predict_cli()
         ```
@@ -53,9 +53,8 @@ class ClassificationPredictor(BasePredictor):
         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

-
-
-        orig_img =
-        img_path
-
-        return results
+        preds = preds[0] if isinstance(preds, (list, tuple)) else preds
+        return [
+            Results(orig_img, path=img_path, names=self.model.names, probs=pred)
+            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
+        ]
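
The rewritten `postprocess` above first unwraps list/tuple outputs and then pairs predictions with images via `zip`; a toy illustration of that flow, where all names and shapes are assumptions:

```python
# Hedged toy sketch of the unwrap-then-zip pattern in postprocess.
import torch

preds = (torch.rand(2, 5),)  # some export backends return a (probs,) tuple
preds = preds[0] if isinstance(preds, (list, tuple)) else preds  # unwrap
orig_imgs, paths = ["img0", "img1"], ["a.jpg", "b.jpg"]  # stand-ins
rows = [(int(p.argmax()), img, path) for p, img, path in zip(preds, orig_imgs, paths)]
print(rows)  # one (top-1 class, image, path) triple per input image
```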

ultralytics/models/yolo/classify/train.py CHANGED
@@ -1,13 +1,14 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from copy import copy

 import torch
-import torchvision

 from ultralytics.data import ClassificationDataset, build_dataloader
 from ultralytics.engine.trainer import BaseTrainer
 from ultralytics.models import yolo
-from ultralytics.nn.tasks import ClassificationModel
-from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
+from ultralytics.nn.tasks import ClassificationModel
+from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
 from ultralytics.utils.plotting import plot_images, plot_results
 from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first

@@ -23,7 +24,7 @@ class ClassificationTrainer(BaseTrainer):
         ```python
         from ultralytics.models.yolo.classify import ClassificationTrainer

-        args = dict(model=
+        args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3)
         trainer = ClassificationTrainer(overrides=args)
         trainer.train()
         ```
@@ -59,23 +60,16 @@ class ClassificationTrainer(BaseTrainer):

     def setup_model(self):
         """Load, create or download model for any task."""
-
-
-
-
-
-
-
-        for p in self.model.parameters():
-            p.requires_grad = True  # for training
-        elif model.split(".")[-1] in ("yaml", "yml"):
-            self.model = self.get_model(cfg=model)
-        elif model in torchvision.models.__dict__:
-            self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
+        import torchvision  # scope for faster 'import ultralytics'
+
+        if str(self.model) in torchvision.models.__dict__:
+            self.model = torchvision.models.__dict__[self.model](
+                weights="IMAGENET1K_V1" if self.args.pretrained else None
+            )
+            ckpt = None
         else:
-
+            ckpt = super().setup_model()
         ClassificationModel.reshape_outputs(self.model, self.data["nc"])
-
         return ckpt

     def build_dataset(self, img_path, mode="train", batch=None):
@@ -115,7 +109,9 @@ class ClassificationTrainer(BaseTrainer):
     def get_validator(self):
         """Returns an instance of ClassificationValidator for validation."""
         self.loss_names = ["loss"]
-        return yolo.classify.ClassificationValidator(
+        return yolo.classify.ClassificationValidator(
+            self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )

     def label_loss_items(self, loss_items=None, prefix="train"):
         """
@@ -145,7 +141,6 @@ class ClassificationTrainer(BaseTrainer):
             self.metrics = self.validator(model=f)
             self.metrics.pop("fitness", None)
             self.run_callbacks("on_fit_epoch_end")
-        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")

     def plot_training_samples(self, batch, ni):
         """Plots training samples with their annotations."""
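
The rewritten `setup_model` defers the torchvision import (per its own comment, to keep `import ultralytics` fast) and dispatches on the model name; a standalone sketch of that branch, with "resnet18" as an assumed stand-in for `self.model`:

```python
# Hedged sketch of the torchvision branch in setup_model.
import torchvision  # deferred in the real code: scope for faster 'import ultralytics'

name = "resnet18"  # assumed stand-in for self.model
if name in torchvision.models.__dict__:
    # weights="IMAGENET1K_V1" would mirror the pretrained path above
    model = torchvision.models.__dict__[name](weights=None)
    print(type(model).__name__)  # ResNet
```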

ultralytics/models/yolo/classify/val.py CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import torch

@@ -20,7 +20,7 @@ class ClassificationValidator(BaseValidator):
         ```python
         from ultralytics.models.yolo.classify import ClassificationValidator

-        args = dict(model=
+        args = dict(model="yolov8n-cls.pt", data="imagenet10")
         validator = ClassificationValidator(args=args)
         validator()
         ```
@@ -56,8 +56,8 @@ class ClassificationValidator(BaseValidator):
     def update_metrics(self, preds, batch):
         """Updates running metrics with model predictions and batch targets."""
         n5 = min(len(self.names), 5)
-        self.pred.append(preds.argsort(1, descending=True)[:, :n5])
-        self.targets.append(batch["cls"])
+        self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
+        self.targets.append(batch["cls"].type(torch.int32).cpu())

     def finalize_metrics(self, *args, **kwargs):
         """Finalizes metrics of the model such as confusion_matrix and speed."""
@@ -71,6 +71,10 @@ class ClassificationValidator(BaseValidator):
         self.metrics.confusion_matrix = self.confusion_matrix
         self.metrics.save_dir = self.save_dir

+    def postprocess(self, preds):
+        """Preprocesses the classification predictions."""
+        return preds[0] if isinstance(preds, (list, tuple)) else preds
+
     def get_stats(self):
         """Returns a dictionary of metrics obtained by processing targets and predictions."""
         self.metrics.process(self.targets, self.pred)
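
The `update_metrics` change stores top-5 indices as int32 tensors on the CPU instead of device tensors; a toy sketch of that extraction, with batch size and class count assumed:

```python
# Hedged sketch of the top-5 extraction now cast to int32 on the CPU.
import torch

preds = torch.rand(4, 10)  # (batch, num_classes) scores, made-up shape
n5 = min(preds.shape[1], 5)
top5 = preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu()
print(top5.shape, top5.dtype)  # torch.Size([4, 5]) torch.int32
```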

ultralytics/models/yolo/detect/predict.py CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from ultralytics.engine.predictor import BasePredictor
 from ultralytics.engine.results import Results
@@ -14,7 +14,7 @@ class DetectionPredictor(BasePredictor):
         from ultralytics.utils import ASSETS
         from ultralytics.models.yolo.detect import DetectionPredictor

-        args = dict(model=
+        args = dict(model="yolo11n.pt", source=ASSETS)
         predictor = DetectionPredictor(overrides=args)
         predictor.predict_cli()
         ```
@@ -35,9 +35,7 @@ class DetectionPredictor(BasePredictor):
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

         results = []
-        for
-            orig_img = orig_imgs[i]
+        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            img_path = self.batch[0][i]
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results

ultralytics/models/yolo/detect/train.py CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import math
 import random
@@ -24,7 +24,7 @@ class DetectionTrainer(BaseTrainer):
         ```python
         from ultralytics.models.yolo.detect import DetectionTrainer

-        args = dict(model=
+        args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
         trainer = DetectionTrainer(overrides=args)
         trainer.train()
         ```
@@ -44,7 +44,7 @@ class DetectionTrainer(BaseTrainer):

     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
         """Construct and return dataloader."""
-        assert mode in
+        assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
         with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
             dataset = self.build_dataset(dataset_path, mode, batch_size)
         shuffle = mode == "train"
@@ -60,7 +60,7 @@ class DetectionTrainer(BaseTrainer):
         if self.args.multi_scale:
             imgs = batch["img"]
             sz = (
-                random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride)
+                random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
                 // self.stride
                 * self.stride
             )  # size
@@ -141,3 +141,10 @@ class DetectionTrainer(BaseTrainer):
         boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
         cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
         plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
+
+    def auto_batch(self):
+        """Get batch size by calculating memory occupation of model."""
+        train_dataset = self.build_dataset(self.trainset, mode="train", batch=16)
+        # 4 for mosaic augmentation
+        max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4
+        return super().auto_batch(max_num_obj)
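
The `random.randrange` fix above matters because recent Python versions reject non-integer arguments to `randrange`; a standalone sketch with assumed `imgsz=640` and `stride=32` shows the drawn size stays a stride multiple:

```python
# Hedged sketch of the multi-scale size draw after the int() fix.
import random

imgsz, stride = 640, 32  # assumed values
sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + stride)) // stride * stride
assert sz % stride == 0 and 320 <= sz <= 960
print(sz)  # e.g. 544
```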

ultralytics/models/yolo/detect/val.py CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import os
 from pathlib import Path
@@ -22,7 +22,7 @@ class DetectionValidator(BaseValidator):
         ```python
         from ultralytics.models.yolo.detect import DetectionValidator

-        args = dict(model=
+        args = dict(model="yolo11n.pt", data="coco8.yaml")
         validator = DetectionValidator(args=args)
         validator()
         ```
@@ -32,13 +32,20 @@ class DetectionValidator(BaseValidator):
         """Initialize detection model with necessary variables and settings."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.nt_per_class = None
+        self.nt_per_image = None
         self.is_coco = False
+        self.is_lvis = False
         self.class_map = None
         self.args.task = "detect"
         self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
         self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
         self.niou = self.iouv.numel()
         self.lb = []  # for autolabelling
+        if self.args.save_hybrid:
+            LOGGER.warning(
+                "WARNING ⚠️ 'save_hybrid=True' will append ground truth to predictions for autolabelling.\n"
+                "WARNING ⚠️ 'save_hybrid=True' will cause incorrect mAP.\n"
+            )

     def preprocess(self, batch):
         """Preprocesses batch of images for YOLO training."""
@@ -51,23 +58,24 @@ class DetectionValidator(BaseValidator):
             height, width = batch["img"].shape[2:]
             nb = len(batch["img"])
             bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
-            self.lb =
-                [
-
-
-                ]
-                if self.args.save_hybrid
-                else []
-            )  # for autolabelling
+            self.lb = [
+                torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
+                for i in range(nb)
+            ]

         return batch

     def init_metrics(self, model):
         """Initialize evaluation metrics for YOLO."""
         val = self.data.get(self.args.split, "")  # validation path
-        self.is_coco =
-
-
+        self.is_coco = (
+            isinstance(val, str)
+            and "coco" in val
+            and (val.endswith(f"{os.sep}val2017.txt") or val.endswith(f"{os.sep}test-dev2017.txt"))
+        )  # is COCO
+        self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco  # is LVIS
+        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1))
+        self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training  # run final val
         self.names = model.names
         self.nc = len(model.names)
         self.metrics.names = self.names
@@ -75,7 +83,7 @@ class DetectionValidator(BaseValidator):
         self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
         self.seen = 0
         self.jdict = []
-        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[])
+        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])

     def get_desc(self):
         """Return a formatted string summarizing class metrics of YOLO model."""
@@ -89,7 +97,7 @@ class DetectionValidator(BaseValidator):
             self.args.iou,
             labels=self.lb,
             multi_label=True,
-            agnostic=self.args.single_cls,
+            agnostic=self.args.single_cls or self.args.agnostic_nms,
             max_det=self.args.max_det,
         )

@@ -104,7 +112,7 @@ class DetectionValidator(BaseValidator):
         if len(cls):
             bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels
-        return
+        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}

     def _prepare_pred(self, pred, pbatch):
         """Prepares a batch of images and annotations for validation."""
@@ -128,6 +136,7 @@ class DetectionValidator(BaseValidator):
             cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
             nl = len(cls)
             stat["target_cls"] = cls
+            stat["target_img"] = cls.unique()
             if npr == 0:
                 if nl:
                     for k in self.stats.keys():
@@ -146,8 +155,8 @@ class DetectionValidator(BaseValidator):
             # Evaluate
             if nl:
                 stat["tp"] = self._process_batch(predn, bbox, cls)
-
-
+            if self.args.plots:
+                self.confusion_matrix.process_batch(predn, bbox, cls)
             for k in self.stats.keys():
                 self.stats[k].append(stat[k])

@@ -155,8 +164,12 @@ class DetectionValidator(BaseValidator):
             if self.args.save_json:
                 self.pred_to_json(predn, batch["im_file"][si])
             if self.args.save_txt:
-
-
+                self.save_one_txt(
+                    predn,
+                    self.args.save_conf,
+                    pbatch["ori_shape"],
+                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
+                )

     def finalize_metrics(self, *args, **kwargs):
         """Set final values for metrics speed and confusion matrix."""
@@ -166,11 +179,11 @@ class DetectionValidator(BaseValidator):
     def get_stats(self):
         """Returns metrics statistics and results dictionary."""
         stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
+        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
+        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
+        stats.pop("target_img", None)
         if len(stats) and stats["tp"].any():
             self.metrics.process(**stats)
-        self.nt_per_class = np.bincount(
-            stats["target_cls"].astype(int), minlength=self.nc
-        )  # number of targets per class
         return self.metrics.results_dict

     def print_results(self):
@@ -183,7 +196,9 @@ class DetectionValidator(BaseValidator):
         # Print results per class
         if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
             for i, c in enumerate(self.metrics.ap_class_index):
-                LOGGER.info(
+                LOGGER.info(
+                    pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
+                )

         if self.args.plots:
             for normalize in True, False:
@@ -196,13 +211,18 @@ class DetectionValidator(BaseValidator):
         Return correct prediction matrix.

         Args:
-            detections (torch.Tensor): Tensor of shape
-
-
-
+            detections (torch.Tensor): Tensor of shape (N, 6) representing detections where each detection is
+                (x1, y1, x2, y2, conf, class).
+            gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground-truth bounding box coordinates. Each
+                bounding box is of the format: (x1, y1, x2, y2).
+            gt_cls (torch.Tensor): Tensor of shape (M,) representing target class indices.

         Returns:
-            (torch.Tensor): Correct prediction matrix of shape
+            (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
+
+        Note:
+            The function does not return any value directly usable for metrics calculation. Instead, it provides an
+            intermediate representation used for evaluating predictions against ground truth.
         """
         iou = box_iou(gt_bboxes, detections[:, :4])
         return self.match_predictions(detections[:, 5], gt_cls, iou)
@@ -249,12 +269,14 @@ class DetectionValidator(BaseValidator):

     def save_one_txt(self, predn, save_conf, shape, file):
         """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
-
-
-
-
-
-
+        from ultralytics.engine.results import Results
+
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            boxes=predn[:, :6],
+        ).save_txt(file, save_conf=save_conf)

     def pred_to_json(self, predn, filename):
         """Serialize YOLO predictions to COCO json format."""
@@ -274,26 +296,42 @@ class DetectionValidator(BaseValidator):

     def eval_json(self, stats):
         """Evaluates YOLO output in JSON format and returns performance statistics."""
-        if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+        if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
             pred_json = self.save_dir / "predictions.json"  # predictions
-
+            anno_json = (
+                self.data["path"]
+                / "annotations"
+                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
+            )  # annotations
+            pkg = "pycocotools" if self.is_coco else "lvis"
+            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-
-                from pycocotools.coco import COCO  # noqa
-                from pycocotools.cocoeval import COCOeval  # noqa
-
-                for x in anno_json, pred_json:
+                for x in pred_json, anno_json:
                     assert x.is_file(), f"{x} file not found"
-
-                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                eval = COCOeval(anno, pred, "bbox")
+                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
                 if self.is_coco:
-
-
-
-
-
+                    from pycocotools.coco import COCO  # noqa
+                    from pycocotools.cocoeval import COCOeval  # noqa
+
+                    anno = COCO(str(anno_json))  # init annotations api
+                    pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                    val = COCOeval(anno, pred, "bbox")
+                else:
+                    from lvis import LVIS, LVISEval
+
+                    anno = LVIS(str(anno_json))  # init annotations api
+                    pred = anno._load_json(str(pred_json))  # init predictions api (must pass string, not Path)
+                    val = LVISEval(anno, pred, "bbox")
+                val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
+                val.evaluate()
+                val.accumulate()
+                val.summarize()
+                if self.is_lvis:
+                    val.print_results()  # explicitly call print_results
+                # update mAP50-95 and mAP50
+                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
+                    val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]]
+                )
             except Exception as e:
-                LOGGER.warning(f"
+                LOGGER.warning(f"{pkg} unable to run: {e}")
         return stats
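
For reference, a minimal sketch of the pycocotools flow the COCO branch of `eval_json` follows; the file paths below are placeholders, not values the validator actually uses.

```python
# Hedged sketch of the COCO branch in eval_json; paths are placeholders.
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

anno = COCO("annotations/instances_val2017.json")  # ground-truth annotations
pred = anno.loadRes("runs/detect/val/predictions.json")  # must be a string, not a Path
val = COCOeval(anno, pred, "bbox")
val.evaluate()
val.accumulate()
val.summarize()
map5095, map50 = val.stats[:2]  # the two values read back into stats above
print(map5095, map50)
```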