dgenerate-ultralytics-headless 8.3.137-py3-none-any.whl → 8.3.224-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/segment/train.py

@@ -1,22 +1,23 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from copy import copy
+from pathlib import Path
 
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import SegmentationModel
 from ultralytics.utils import DEFAULT_CFG, RANK
-from ultralytics.utils.plotting import plot_images, plot_results
 
 
 class SegmentationTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training based on a segmentation model.
+    """A class extending the DetectionTrainer class for training based on a segmentation model.
 
     This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
     functionality including model initialization, validation, and visualization.
 
     Attributes:
-        loss_names (Tuple[str]): Names of the loss components used during training.
+        loss_names (tuple[str]): Names of the loss components used during training.
 
     Examples:
         >>> from ultralytics.models.yolo.segment import SegmentationTrainer
@@ -25,36 +26,25 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
         >>> trainer.train()
     """
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize a SegmentationTrainer object.
-
-        This initializes a trainer for segmentation tasks, extending the detection trainer with segmentation-specific
-        functionality. It sets the task to 'segment' and prepares the trainer for training segmentation models.
+    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
+        """Initialize a SegmentationTrainer object.
 
         Args:
-            cfg (dict): Configuration dictionary with default training settings.
+            cfg (dict): Configuration dictionary with default training settings.
             overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
             _callbacks (list, optional): List of callback functions to be executed during training.
-
-        Examples:
-            >>> from ultralytics.models.yolo.segment import SegmentationTrainer
-            >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
-            >>> trainer = SegmentationTrainer(overrides=args)
-            >>> trainer.train()
         """
         if overrides is None:
             overrides = {}
         overrides["task"] = "segment"
         super().__init__(cfg, overrides, _callbacks)
 
-    def get_model(self, cfg=None, weights=None, verbose=True):
-        """
-        Initialize and return a SegmentationModel with specified configuration and weights.
+    def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
+        """Initialize and return a SegmentationModel with specified configuration and weights.
 
         Args:
-            cfg (dict | str | None): Model configuration. Can be a dictionary, a path to a YAML file, or None.
-            weights (str | Path | None): Path to pretrained weights file.
+            cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
+            weights (str | Path, optional): Path to pretrained weights file.
             verbose (bool): Whether to display model information during initialization.
 
         Returns:
@@ -77,47 +67,3 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
         return yolo.segment.SegmentationValidator(
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
-
-    def plot_training_samples(self, batch, ni):
-        """
-        Plot training sample images with labels, bounding boxes, and masks.
-
-        This method creates a visualization of training batch images with their corresponding labels, bounding boxes,
-        and segmentation masks, saving the result to a file for inspection and debugging.
-
-        Args:
-            batch (dict): Dictionary containing batch data with the following keys:
-                'img': Images tensor
-                'batch_idx': Batch indices for each box
-                'cls': Class labels tensor (squeezed to remove last dimension)
-                'bboxes': Bounding box coordinates tensor
-                'masks': Segmentation masks tensor
-                'im_file': List of image file paths
-            ni (int): Current training iteration number, used for naming the output file.
-
-        Examples:
-            >>> trainer = SegmentationTrainer()
-            >>> batch = {
-            ...     "img": torch.rand(16, 3, 640, 640),
-            ...     "batch_idx": torch.zeros(16),
-            ...     "cls": torch.randint(0, 80, (16, 1)),
-            ...     "bboxes": torch.rand(16, 4),
-            ...     "masks": torch.rand(16, 640, 640),
-            ...     "im_file": ["image1.jpg", "image2.jpg"],
-            ... }
-            >>> trainer.plot_training_samples(batch, ni=5)
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            masks=batch["masks"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"train_batch{ni}.jpg",
-            on_plot=self.on_plot,
-        )
-
-    def plot_metrics(self):
-        """Plots training/val metrics."""
-        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
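The hunks above remove the segmentation-specific `plot_training_samples` and `plot_metrics` overrides (along with the `plot_images`/`plot_results` imports), leaving plotting to the base `DetectionTrainer`. The public training entry point is unchanged between 8.3.137 and 8.3.224; a minimal usage sketch taken from the docstring example retained in the diff, where `yolo11n-seg.pt` and `coco8-seg.yaml` are the stock Ultralytics assets it references:

```python
from ultralytics.models.yolo.segment import SegmentationTrainer

# Mirrors the docstring example kept in the class above; epochs=3 is illustrative.
args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
trainer = SegmentationTrainer(overrides=args)  # __init__ forces overrides["task"] = "segment"
trainer.train()
```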
ultralytics/models/yolo/segment/val.py

@@ -1,7 +1,10 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
+from typing import Any
 
 import numpy as np
 import torch
@@ -10,16 +13,14 @@ import torch.nn.functional as F
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, NUM_THREADS, ops
 from ultralytics.utils.checks import check_requirements
-from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
-from ultralytics.utils.plotting import output_to_target, plot_images
+from ultralytics.utils.metrics import SegmentMetrics, mask_iou
 
 
 class SegmentationValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on a segmentation model.
+    """A class extending the DetectionValidator class for validation based on a segmentation model.
 
-    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
-    to compute metrics such as mAP for both detection and segmentation tasks.
+    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions to
+    compute metrics such as mAP for both detection and segmentation tasks.
 
     Attributes:
         plot_masks (list): List to store masks for plotting.
@@ -35,45 +36,46 @@ class SegmentationValidator(DetectionValidator):
         >>> validator()
     """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
-        Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
+    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
+        """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
 
         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
             save_dir (Path, optional): Directory to save results.
-            pbar (Any, optional): Progress bar for displaying progress.
             args (namespace, optional): Arguments for the validator.
             _callbacks (list, optional): List of callback functions.
         """
-        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
-        self.plot_masks = None
+        super().__init__(dataloader, save_dir, args, _callbacks)
         self.process = None
         self.args.task = "segment"
-        self.metrics = SegmentMetrics(save_dir=self.save_dir)
+        self.metrics = SegmentMetrics()
+
+    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
+        """Preprocess batch of images for YOLO segmentation validation.
 
-    def preprocess(self, batch):
-        """Preprocess batch by converting masks to float and sending to device."""
+        Args:
+            batch (dict[str, Any]): Batch containing images and annotations.
+
+        Returns:
+            (dict[str, Any]): Preprocessed batch.
+        """
         batch = super().preprocess(batch)
-        batch["masks"] = batch["masks"].to(self.device).float()
+        batch["masks"] = batch["masks"].float()
         return batch
 
-    def init_metrics(self, model):
-        """
-        Initialize metrics and select mask processing function based on save_json flag.
+    def init_metrics(self, model: torch.nn.Module) -> None:
+        """Initialize metrics and select mask processing function based on save_json flag.
 
         Args:
             model (torch.nn.Module): Model to validate.
         """
         super().init_metrics(model)
-        self.plot_masks = []
         if self.args.save_json:
-            check_requirements("pycocotools>=2.0.6")
-        # more accurate vs faster
+            check_requirements("faster-coco-eval>=1.6.7")
+        # More accurate vs faster
         self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
-        self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
 
-    def get_desc(self):
+    def get_desc(self) -> str:
         """Return a formatted description of evaluation metrics."""
         return ("%22s" + "%11s" * 10) % (
             "Class",
@@ -89,238 +91,108 @@ class SegmentationValidator(DetectionValidator):
             "mAP50-95)",
         )
 
-    def postprocess(self, preds):
-        """
-        Post-process YOLO predictions and return output detections with proto.
+    def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
+        """Post-process YOLO predictions and return output detections with proto.
 
         Args:
-            preds (list): Raw predictions from the model.
+            preds (list[torch.Tensor]): Raw predictions from the model.
 
         Returns:
-            p (list[torch.Tensor]): Processed detection predictions.
-            proto (torch.Tensor): Prototype masks for segmentation.
+            list[dict[str, torch.Tensor]]: Processed detection predictions with masks.
         """
-        p = super().postprocess(preds[0])
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
-        return p, proto
+        preds = super().postprocess(preds[0])
+        imgsz = [4 * x for x in proto.shape[2:]]  # get image size from proto
+        for i, pred in enumerate(preds):
+            coefficient = pred.pop("extra")
+            pred["masks"] = (
+                self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
+                if coefficient.shape[0]
+                else torch.zeros(
+                    (0, *(imgsz if self.process is ops.process_mask_native else proto.shape[2:])),
+                    dtype=torch.uint8,
+                    device=pred["bboxes"].device,
+                )
+            )
+        return preds
 
-    def _prepare_batch(self, si, batch):
-        """
-        Prepare a batch for training or inference by processing images and targets.
+    def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
+        """Prepare a batch for training or inference by processing images and targets.
 
         Args:
             si (int): Batch index.
-            batch (dict): Batch data containing images and annotations.
+            batch (dict[str, Any]): Batch data containing images and annotations.
 
         Returns:
-            (dict): Prepared batch with processed annotations.
+            (dict[str, Any]): Prepared batch with processed annotations.
         """
         prepared_batch = super()._prepare_batch(si, batch)
-        midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
-        prepared_batch["masks"] = batch["masks"][midx]
+        nl = prepared_batch["cls"].shape[0]
+        if self.args.overlap_mask:
+            masks = batch["masks"][si]
+            index = torch.arange(1, nl + 1, device=masks.device).view(nl, 1, 1)
+            masks = (masks == index).float()
+        else:
+            masks = batch["masks"][batch["batch_idx"] == si]
+        if nl:
+            mask_size = [s if self.process is ops.process_mask_native else s // 4 for s in prepared_batch["imgsz"]]
+            if masks.shape[1:] != mask_size:
+                masks = F.interpolate(masks[None], mask_size, mode="bilinear", align_corners=False)[0]
+                masks = masks.gt_(0.5)
+        prepared_batch["masks"] = masks
         return prepared_batch
 
-    def _prepare_pred(self, pred, pbatch, proto):
-        """
-        Prepare predictions for evaluation by processing bounding boxes and masks.
+    def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
+        """Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
 
         Args:
-            pred (torch.Tensor): Raw predictions from the model.
-            pbatch (dict): Prepared batch information.
-            proto (torch.Tensor): Prototype masks for segmentation.
+            preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
+            batch (dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.
 
         Returns:
-            predn (torch.Tensor): Processed bounding box predictions.
-            pred_masks (torch.Tensor): Processed mask predictions.
-        """
-        predn = super()._prepare_pred(pred, pbatch)
-        pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
-        return predn, pred_masks
-
-    def update_metrics(self, preds, batch):
-        """
-        Update metrics with the current batch predictions and targets.
-
-        Args:
-            preds (list): Predictions from the model.
-            batch (dict): Batch data containing images and targets.
-        """
-        for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
-            self.seen += 1
-            npr = len(pred)
-            stat = dict(
-                conf=torch.zeros(0, device=self.device),
-                pred_cls=torch.zeros(0, device=self.device),
-                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-            )
-            pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
-            nl = len(cls)
-            stat["target_cls"] = cls
-            stat["target_img"] = cls.unique()
-            if npr == 0:
-                if nl:
-                    for k in self.stats.keys():
-                        self.stats[k].append(stat[k])
-                    if self.args.plots:
-                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
-                continue
-
-            # Masks
-            gt_masks = pbatch.pop("masks")
-            # Predictions
-            if self.args.single_cls:
-                pred[:, 5] = 0
-            predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
-            stat["conf"] = predn[:, 4]
-            stat["pred_cls"] = predn[:, 5]
-
-            # Evaluate
-            if nl:
-                stat["tp"] = self._process_batch(predn, bbox, cls)
-                stat["tp_m"] = self._process_batch(
-                    predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
-                )
-            if self.args.plots:
-                self.confusion_matrix.process_batch(predn, bbox, cls)
-
-            for k in self.stats.keys():
-                self.stats[k].append(stat[k])
-
-            pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
-            if self.args.plots and self.batch_i < 3:
-                self.plot_masks.append(pred_masks[:50].cpu())  # Limit plotted items for speed
-                if pred_masks.shape[0] > 50:
-                    LOGGER.warning("Limiting validation plots to first 50 items per image for speed...")
-
-            # Save
-            if self.args.save_json:
-                self.pred_to_json(
-                    predn,
-                    batch["im_file"][si],
-                    ops.scale_image(
-                        pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
-                        pbatch["ori_shape"],
-                        ratio_pad=batch["ratio_pad"][si],
-                    ),
-                )
-            if self.args.save_txt:
-                self.save_one_txt(
-                    predn,
-                    pred_masks,
-                    self.args.save_conf,
-                    pbatch["ori_shape"],
-                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
-                )
-
-    def finalize_metrics(self, *args, **kwargs):
-        """
-        Finalize evaluation metrics by setting the speed attribute in the metrics object.
-
-        This method is called at the end of validation to set the processing speed for the metrics calculations.
-        It transfers the validator's speed measurement to the metrics object for reporting.
+            (dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.
 
-        Args:
-            *args (Any): Variable length argument list.
-            **kwargs (Any): Arbitrary keyword arguments.
-        """
-        self.metrics.speed = self.speed
-        self.metrics.confusion_matrix = self.confusion_matrix
-
-    def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
-        """
-        Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
-
-        Args:
-            detections (torch.Tensor): Tensor of shape (N, 6) representing detected bounding boxes and
-                associated confidence scores and class indices. Each row is of the format [x1, y1, x2, y2, conf, class].
-            gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
-                Each row is of the format [x1, y1, x2, y2].
-            gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
-            pred_masks (torch.Tensor, optional): Tensor representing predicted masks, if available. The shape should
-                match the ground truth masks.
-            gt_masks (torch.Tensor, optional): Tensor of shape (M, H, W) representing ground truth masks, if available.
-            overlap (bool): Flag indicating if overlapping masks should be considered.
-            masks (bool): Flag indicating if the batch contains mask data.
-
-        Returns:
-            (torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels.
+        Examples:
+            >>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
+            >>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
+            >>> correct_preds = validator._process_batch(preds, batch)
 
-        Note:
+        Notes:
             - If `masks` is True, the function computes IoU between predicted and ground truth masks.
             - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
-
-        Examples:
-            >>> detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
-            >>> gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
-            >>> gt_cls = torch.tensor([1, 0])
-            >>> correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
-        """
-        if masks:
-            if overlap:
-                nl = len(gt_cls)
-                index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
-                gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
-                gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
-            if gt_masks.shape[1:] != pred_masks.shape[1:]:
-                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
-                gt_masks = gt_masks.gt_(0.5)
-            iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
-        else:  # boxes
-            iou = box_iou(gt_bboxes, detections[:, :4])
-
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
-
-    def plot_val_samples(self, batch, ni):
-        """
-        Plot validation samples with bounding box labels and masks.
-
-        Args:
-            batch (dict): Batch data containing images and targets.
-            ni (int): Batch index.
         """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            masks=batch["masks"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )
+        tp = super()._process_batch(preds, batch)
+        gt_cls = batch["cls"]
+        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
+            tp_m = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
+        else:
+            iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1).float())  # float, uint8
+            tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
+        tp.update({"tp_m": tp_m})  # update tp with mask IoU
+        return tp
 
-    def plot_predictions(self, batch, preds, ni):
-        """
-        Plot batch predictions with masks and bounding boxes.
+    def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
+        """Plot batch predictions with masks and bounding boxes.
 
         Args:
-            batch (dict): Batch data containing images and targets.
-            preds (list): Predictions from the model.
+            batch (dict[str, Any]): Batch containing images and annotations.
+            preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
             ni (int): Batch index.
         """
-        plot_images(
-            batch["img"],
-            *output_to_target(preds[0], max_det=50),  # not set to self.args.max_det due to slow plotting speed
-            torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
-        self.plot_masks.clear()
+        for p in preds:
+            masks = p["masks"]
+            if masks.shape[0] > self.args.max_det:
+                LOGGER.warning(f"Limiting validation plots to 'max_det={self.args.max_det}' items.")
+            p["masks"] = torch.as_tensor(masks[: self.args.max_det], dtype=torch.uint8).cpu()
+        super().plot_predictions(batch, preds, ni, max_det=self.args.max_det)  # plot bboxes
 
-    def save_one_txt(self, predn, pred_masks, save_conf, shape, file):
-        """
-        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+    def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: tuple[int, int], file: Path) -> None:
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format.
 
         Args:
-            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
-            pred_masks (torch.Tensor): Predicted masks.
+            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
             save_conf (bool): Whether to save confidence scores.
-            shape (tuple): Shape of the original image.
+            shape (tuple[int, int]): Shape of the original image.
             file (Path): File path to save the detections.
         """
         from ultralytics.engine.results import Results
@@ -329,23 +201,18 @@ class SegmentationValidator(DetectionValidator):
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            boxes=predn[:, :6],
-            masks=pred_masks,
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
+            masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
         ).save_txt(file, save_conf=save_conf)
 
-    def pred_to_json(self, predn, filename, pred_masks):
-        """
-        Save one JSON result for COCO evaluation.
+    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
+        """Save one JSON result for COCO evaluation.
 
         Args:
-            predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
-            filename (str): Image filename.
-            pred_masks (numpy.ndarray): Predicted masks.
-
-        Examples:
-            >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
+            predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
+            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
         """
-        from pycocotools.mask import encode  # noqa
+        from faster_coco_eval.core.mask import encode
 
         def single_encode(x):
             """Encode predicted masks as RLE and append results to jdict."""
@@ -353,76 +220,30 @@ class SegmentationValidator(DetectionValidator):
             rle["counts"] = rle["counts"].decode("utf-8")
             return rle
 
-        stem = Path(filename).stem
-        image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.xyxy2xywh(predn[:, :4])  # xywh
-        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-        pred_masks = np.transpose(pred_masks, (2, 0, 1))
+        pred_masks = np.transpose(predn["masks"], (2, 0, 1))
         with ThreadPool(NUM_THREADS) as pool:
             rles = pool.map(single_encode, pred_masks)
-        for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
-            self.jdict.append(
-                {
-                    "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
-                    "bbox": [round(x, 3) for x in b],
-                    "score": round(p[4], 5),
-                    "segmentation": rles[i],
-                }
-            )
-
-    def eval_json(self, stats):
-        """Return COCO-style object detection evaluation metrics."""
-        if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
-            pred_json = self.save_dir / "predictions.json"  # predictions
-            anno_json = (
-                self.data["path"]
-                / "annotations"
-                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
-            )  # annotations
-            pkg = "pycocotools" if self.is_coco else "lvis"
-            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
-            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-                for x in anno_json, pred_json:
-                    assert x.is_file(), f"{x} file not found"
-                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
-                if self.is_coco:
-                    from pycocotools.coco import COCO  # noqa
-                    from pycocotools.cocoeval import COCOeval  # noqa
-
-                    anno = COCO(str(anno_json))  # init annotations api
-                    pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                    vals = [COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]
-                else:
-                    from lvis import LVIS, LVISEval
-
-                    anno = LVIS(str(anno_json))
-                    pred = anno._load_json(str(pred_json))
-                    vals = [LVISEval(anno, pred, "bbox"), LVISEval(anno, pred, "segm")]
-
-                for i, eval in enumerate(vals):
-                    eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
-                    eval.evaluate()
-                    eval.accumulate()
-                    eval.summarize()
-                    if self.is_lvis:
-                        eval.print_results()
-                    idx = i * 4 + 2
-                    # update mAP50-95 and mAP50
-                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = (
-                        eval.stats[:2] if self.is_coco else [eval.results["AP"], eval.results["AP50"]]
-                    )
-                    if self.is_lvis:
-                        tag = "B" if i == 0 else "M"
-                        stats[f"metrics/APr({tag})"] = eval.results["APr"]
-                        stats[f"metrics/APc({tag})"] = eval.results["APc"]
-                        stats[f"metrics/APf({tag})"] = eval.results["APf"]
-
-                if self.is_lvis:
-                    stats["fitness"] = stats["metrics/mAP50-95(B)"]
-
-            except Exception as e:
-                LOGGER.warning(f"{pkg} unable to run: {e}")
-        return stats
+        super().pred_to_json(predn, pbatch)
+        for i, r in enumerate(rles):
+            self.jdict[-len(rles) + i]["segmentation"] = r  # segmentation
+
+    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
+        """Scales predictions to the original image size."""
+        return {
+            **super().scale_preds(predn, pbatch),
+            "masks": ops.scale_image(
+                torch.as_tensor(predn["masks"], dtype=torch.uint8).permute(1, 2, 0).contiguous().cpu().numpy(),
+                pbatch["ori_shape"],
+                ratio_pad=pbatch["ratio_pad"],
+            ),
+        }
+
+    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
+        """Return COCO-style instance segmentation evaluation metrics."""
+        pred_json = self.save_dir / "predictions.json"  # predictions
+        anno_json = (
+            self.data["path"]
+            / "annotations"
+            / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
+        )  # annotations
+        return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "segm"], suffix=["Box", "Mask"])
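The validator rewrite above switches from tuple-based predictions to per-image dicts (`"bboxes"`, `"conf"`, `"cls"`, `"masks"`), swaps the `pycocotools`/`lvis` evaluation path for `faster-coco-eval` via the base class's `coco_evaluate`, and computes mask true positives (`tp_m`) with `mask_iou` over per-instance flattened masks. A minimal sketch of that IoU step, with random tensors standing in for real ground-truth and predicted masks (illustrative only, not code from the package):

```python
import torch

from ultralytics.utils.metrics import mask_iou

# Random stand-ins: M=2 ground-truth and N=3 predicted binary masks at 160x160.
gt_masks = (torch.rand(2, 160, 160) > 0.5).float()
pred_masks = (torch.rand(3, 160, 160) > 0.5).float()

# As in the new _process_batch: flatten each mask to a vector, then compute
# the (M, N) pairwise IoU matrix that match_predictions consumes.
iou = mask_iou(gt_masks.flatten(1), pred_masks.flatten(1))
print(iou.shape)  # torch.Size([2, 3])
```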
|