dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
|
@@ -6,20 +6,20 @@ from pathlib import Path
|
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
|
+
import torch.distributed as dist
|
|
9
10
|
|
|
10
11
|
from ultralytics.data import ClassificationDataset, build_dataloader
|
|
11
12
|
from ultralytics.engine.validator import BaseValidator
|
|
12
|
-
from ultralytics.utils import LOGGER
|
|
13
|
+
from ultralytics.utils import LOGGER, RANK
|
|
13
14
|
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
|
|
14
15
|
from ultralytics.utils.plotting import plot_images
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
class ClassificationValidator(BaseValidator):
|
|
18
|
-
"""
|
|
19
|
-
A class extending the BaseValidator class for validation based on a classification model.
|
|
19
|
+
"""A class extending the BaseValidator class for validation based on a classification model.
|
|
20
20
|
|
|
21
|
-
This validator handles the validation process for classification models, including metrics calculation,
|
|
22
|
-
confusion matrix generation, and visualization of results.
|
|
21
|
+
This validator handles the validation process for classification models, including metrics calculation, confusion
|
|
22
|
+
matrix generation, and visualization of results.
|
|
23
23
|
|
|
24
24
|
Attributes:
|
|
25
25
|
targets (list[torch.Tensor]): Ground truth class labels.
|
|
@@ -54,20 +54,13 @@ class ClassificationValidator(BaseValidator):
|
|
|
54
54
|
"""
|
|
55
55
|
|
|
56
56
|
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
|
57
|
-
"""
|
|
58
|
-
Initialize ClassificationValidator with dataloader, save directory, and other parameters.
|
|
57
|
+
"""Initialize ClassificationValidator with dataloader, save directory, and other parameters.
|
|
59
58
|
|
|
60
59
|
Args:
|
|
61
|
-
dataloader (torch.utils.data.DataLoader, optional):
|
|
60
|
+
dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
|
|
62
61
|
save_dir (str | Path, optional): Directory to save results.
|
|
63
62
|
args (dict, optional): Arguments containing model and validation configuration.
|
|
64
63
|
_callbacks (list, optional): List of callback functions to be called during validation.
|
|
65
|
-
|
|
66
|
-
Examples:
|
|
67
|
-
>>> from ultralytics.models.yolo.classify import ClassificationValidator
|
|
68
|
-
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
|
|
69
|
-
>>> validator = ClassificationValidator(args=args)
|
|
70
|
-
>>> validator()
|
|
71
64
|
"""
|
|
72
65
|
super().__init__(dataloader, save_dir, args, _callbacks)
|
|
73
66
|
self.targets = None
|
|
@@ -89,14 +82,13 @@ class ClassificationValidator(BaseValidator):
|
|
|
89
82
|
|
|
90
83
|
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
|
91
84
|
"""Preprocess input batch by moving data to device and converting to appropriate dtype."""
|
|
92
|
-
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
|
85
|
+
batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
|
|
93
86
|
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
|
|
94
|
-
batch["cls"] = batch["cls"].to(self.device, non_blocking=True)
|
|
87
|
+
batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
|
|
95
88
|
return batch
|
|
96
89
|
|
|
97
90
|
def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
|
|
98
|
-
"""
|
|
99
|
-
Update running metrics with model predictions and batch targets.
|
|
91
|
+
"""Update running metrics with model predictions and batch targets.
|
|
100
92
|
|
|
101
93
|
Args:
|
|
102
94
|
preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
|
|
@@ -111,12 +103,7 @@ class ClassificationValidator(BaseValidator):
|
|
|
111
103
|
self.targets.append(batch["cls"].type(torch.int32).cpu())
|
|
112
104
|
|
|
113
105
|
def finalize_metrics(self) -> None:
|
|
114
|
-
"""
|
|
115
|
-
Finalize metrics including confusion matrix and processing speed.
|
|
116
|
-
|
|
117
|
-
Notes:
|
|
118
|
-
This method processes the accumulated predictions and targets to generate the confusion matrix,
|
|
119
|
-
optionally plots it, and updates the metrics object with speed information.
|
|
106
|
+
"""Finalize metrics including confusion matrix and processing speed.
|
|
120
107
|
|
|
121
108
|
Examples:
|
|
122
109
|
>>> validator = ClassificationValidator()
|
|
@@ -124,6 +111,10 @@ class ClassificationValidator(BaseValidator):
|
|
|
124
111
|
>>> validator.targets = [torch.tensor([0])] # Ground truth class
|
|
125
112
|
>>> validator.finalize_metrics()
|
|
126
113
|
>>> print(validator.metrics.confusion_matrix) # Access the confusion matrix
|
|
114
|
+
|
|
115
|
+
Notes:
|
|
116
|
+
This method processes the accumulated predictions and targets to generate the confusion matrix,
|
|
117
|
+
optionally plots it, and updates the metrics object with speed information.
|
|
127
118
|
"""
|
|
128
119
|
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
|
129
120
|
if self.args.plots:
|
|
@@ -142,13 +133,25 @@ class ClassificationValidator(BaseValidator):
|
|
|
142
133
|
self.metrics.process(self.targets, self.pred)
|
|
143
134
|
return self.metrics.results_dict
|
|
144
135
|
|
|
136
|
+
def gather_stats(self) -> None:
|
|
137
|
+
"""Gather stats from all GPUs."""
|
|
138
|
+
if RANK == 0:
|
|
139
|
+
gathered_preds = [None] * dist.get_world_size()
|
|
140
|
+
gathered_targets = [None] * dist.get_world_size()
|
|
141
|
+
dist.gather_object(self.pred, gathered_preds, dst=0)
|
|
142
|
+
dist.gather_object(self.targets, gathered_targets, dst=0)
|
|
143
|
+
self.pred = [pred for rank in gathered_preds for pred in rank]
|
|
144
|
+
self.targets = [targets for rank in gathered_targets for targets in rank]
|
|
145
|
+
elif RANK > 0:
|
|
146
|
+
dist.gather_object(self.pred, None, dst=0)
|
|
147
|
+
dist.gather_object(self.targets, None, dst=0)
|
|
148
|
+
|
|
145
149
|
def build_dataset(self, img_path: str) -> ClassificationDataset:
|
|
146
150
|
"""Create a ClassificationDataset instance for validation."""
|
|
147
151
|
return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
|
|
148
152
|
|
|
149
153
|
def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
|
|
150
|
-
"""
|
|
151
|
-
Build and return a data loader for classification validation.
|
|
154
|
+
"""Build and return a data loader for classification validation.
|
|
152
155
|
|
|
153
156
|
Args:
|
|
154
157
|
dataset_path (str | Path): Path to the dataset directory.
|
|
@@ -166,8 +169,7 @@ class ClassificationValidator(BaseValidator):
|
|
|
166
169
|
LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
|
|
167
170
|
|
|
168
171
|
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
|
|
169
|
-
"""
|
|
170
|
-
Plot validation image samples with their ground truth labels.
|
|
172
|
+
"""Plot validation image samples with their ground truth labels.
|
|
171
173
|
|
|
172
174
|
Args:
|
|
173
175
|
batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
|
|
@@ -178,7 +180,7 @@ class ClassificationValidator(BaseValidator):
|
|
|
178
180
|
>>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
|
|
179
181
|
>>> validator.plot_val_samples(batch, 0)
|
|
180
182
|
"""
|
|
181
|
-
batch["batch_idx"] = torch.arange(len(batch["img"])) # add batch index for plotting
|
|
183
|
+
batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
|
|
182
184
|
plot_images(
|
|
183
185
|
labels=batch,
|
|
184
186
|
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
|
@@ -187,8 +189,7 @@ class ClassificationValidator(BaseValidator):
|
|
|
187
189
|
)
|
|
188
190
|
|
|
189
191
|
def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
|
|
190
|
-
"""
|
|
191
|
-
Plot images with their predicted class labels and save the visualization.
|
|
192
|
+
"""Plot images with their predicted class labels and save the visualization.
|
|
192
193
|
|
|
193
194
|
Args:
|
|
194
195
|
batch (dict[str, Any]): Batch data containing images and other information.
|
|
@@ -203,8 +204,9 @@ class ClassificationValidator(BaseValidator):
|
|
|
203
204
|
"""
|
|
204
205
|
batched_preds = dict(
|
|
205
206
|
img=batch["img"],
|
|
206
|
-
batch_idx=torch.arange(len(batch["img"])),
|
|
207
|
+
batch_idx=torch.arange(batch["img"].shape[0]),
|
|
207
208
|
cls=torch.argmax(preds, dim=1),
|
|
209
|
+
conf=torch.amax(preds, dim=1),
|
|
208
210
|
)
|
|
209
211
|
plot_images(
|
|
210
212
|
batched_preds,
|
|
@@ -6,8 +6,7 @@ from ultralytics.utils import nms, ops
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class DetectionPredictor(BasePredictor):
|
|
9
|
-
"""
|
|
10
|
-
A class extending the BasePredictor class for prediction based on a detection model.
|
|
9
|
+
"""A class extending the BasePredictor class for prediction based on a detection model.
|
|
11
10
|
|
|
12
11
|
This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
|
|
13
12
|
with bounding boxes and class predictions.
|
|
@@ -32,8 +31,7 @@ class DetectionPredictor(BasePredictor):
|
|
|
32
31
|
"""
|
|
33
32
|
|
|
34
33
|
def postprocess(self, preds, img, orig_imgs, **kwargs):
|
|
35
|
-
"""
|
|
36
|
-
Post-process predictions and return a list of Results objects.
|
|
34
|
+
"""Post-process predictions and return a list of Results objects.
|
|
37
35
|
|
|
38
36
|
This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
|
|
39
37
|
further analysis.
|
|
@@ -67,7 +65,7 @@ class DetectionPredictor(BasePredictor):
|
|
|
67
65
|
)
|
|
68
66
|
|
|
69
67
|
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
|
70
|
-
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
|
68
|
+
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]
|
|
71
69
|
|
|
72
70
|
if save_feats:
|
|
73
71
|
obj_feats = self.get_obj_feats(self._feats, preds[1])
|
|
@@ -81,7 +79,8 @@ class DetectionPredictor(BasePredictor):
|
|
|
81
79
|
|
|
82
80
|
return results
|
|
83
81
|
|
|
84
|
-
def get_obj_feats(self, feat_maps, idxs):
|
|
82
|
+
@staticmethod
|
|
83
|
+
def get_obj_feats(feat_maps, idxs):
|
|
85
84
|
"""Extract object features from the feature maps."""
|
|
86
85
|
import torch
|
|
87
86
|
|
|
@@ -89,11 +88,10 @@ class DetectionPredictor(BasePredictor):
|
|
|
89
88
|
obj_feats = torch.cat(
|
|
90
89
|
[x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1
|
|
91
90
|
) # mean reduce all vectors to same length
|
|
92
|
-
return [feats[idx] if len(idx) else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch
|
|
91
|
+
return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch
|
|
93
92
|
|
|
94
93
|
def construct_results(self, preds, img, orig_imgs):
|
|
95
|
-
"""
|
|
96
|
-
Construct a list of Results objects from model predictions.
|
|
94
|
+
"""Construct a list of Results objects from model predictions.
|
|
97
95
|
|
|
98
96
|
Args:
|
|
99
97
|
preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
|
|
@@ -109,8 +107,7 @@ class DetectionPredictor(BasePredictor):
|
|
|
109
107
|
]
|
|
110
108
|
|
|
111
109
|
def construct_result(self, pred, img, orig_img, img_path):
|
|
112
|
-
"""
|
|
113
|
-
Construct a single Results object from one image prediction.
|
|
110
|
+
"""Construct a single Results object from one image prediction.
|
|
114
111
|
|
|
115
112
|
Args:
|
|
116
113
|
pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
|
|
@@ -17,16 +17,15 @@ from ultralytics.models import yolo
|
|
|
17
17
|
from ultralytics.nn.tasks import DetectionModel
|
|
18
18
|
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
|
|
19
19
|
from ultralytics.utils.patches import override_configs
|
|
20
|
-
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
|
|
20
|
+
from ultralytics.utils.plotting import plot_images, plot_labels
|
|
21
21
|
from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
class DetectionTrainer(BaseTrainer):
|
|
25
|
-
"""
|
|
26
|
-
A class extending the BaseTrainer class for training based on a detection model.
|
|
25
|
+
"""A class extending the BaseTrainer class for training based on a detection model.
|
|
27
26
|
|
|
28
|
-
This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
|
|
29
|
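-
for object detection including dataset building, data loading, preprocessing, and model configuration.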
|
|
27
|
+
This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for
|
|
28
|
+
object detection including dataset building, data loading, preprocessing, and model configuration.
|
|
30
29
|
|
|
31
30
|
Attributes:
|
|
32
31
|
model (DetectionModel): The YOLO detection model being trained.
|
|
@@ -43,7 +42,6 @@ class DetectionTrainer(BaseTrainer):
|
|
|
43
42
|
label_loss_items: Return a loss dictionary with labeled training loss items.
|
|
44
43
|
progress_string: Return a formatted string of training progress.
|
|
45
44
|
plot_training_samples: Plot training samples with their annotations.
|
|
46
|
-
plot_metrics: Plot metrics from a CSV file.
|
|
47
45
|
plot_training_labels: Create a labeled training plot of the YOLO model.
|
|
48
46
|
auto_batch: Calculate optimal batch size based on model memory requirements.
|
|
49
47
|
|
|
@@ -55,8 +53,7 @@ class DetectionTrainer(BaseTrainer):
|
|
|
55
53
|
"""
|
|
56
54
|
|
|
57
55
|
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
|
58
|
-
"""
|
|
59
|
-
Initialize a DetectionTrainer object for training YOLO object detection model training.
|
|
56
|
+
"""Initialize a DetectionTrainer object for training YOLO object detection models.
|
|
60
57
|
|
|
61
58
|
Args:
|
|
62
59
|
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
|
@@ -64,11 +61,9 @@ class DetectionTrainer(BaseTrainer):
|
|
|
64
61
|
_callbacks (list, optional): List of callback functions to be executed during training.
|
|
65
62
|
"""
|
|
66
63
|
super().__init__(cfg, overrides, _callbacks)
|
|
67
|
-
self.dynamic_tensors = ["batch_idx", "cls", "bboxes"]
|
|
68
64
|
|
|
69
65
|
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
|
|
70
|
-
"""
|
|
71
|
-
Build YOLO Dataset for training or validation.
|
|
66
|
+
"""Build YOLO Dataset for training or validation.
|
|
72
67
|
|
|
73
68
|
Args:
|
|
74
69
|
img_path (str): Path to the folder containing images.
|
|
@@ -82,8 +77,7 @@ class DetectionTrainer(BaseTrainer):
|
|
|
82
77
|
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
|
|
83
78
|
|
|
84
79
|
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
|
|
85
|
-
"""
|
|
86
|
-
Construct and return dataloader for the specified mode.
|
|
80
|
+
"""Construct and return dataloader for the specified mode.
|
|
87
81
|
|
|
88
82
|
Args:
|
|
89
83
|
dataset_path (str): Path to the dataset.
|
|
@@ -111,8 +105,7 @@ class DetectionTrainer(BaseTrainer):
|
|
|
111
105
|
)
|
|
112
106
|
|
|
113
107
|
def preprocess_batch(self, batch: dict) -> dict:
|
|
114
|
-
"""
|
|
115
|
-
Preprocess a batch of images by scaling and converting to float.
|
|
108
|
+
"""Preprocess a batch of images by scaling and converting to float.
|
|
116
109
|
|
|
117
110
|
Args:
|
|
118
111
|
batch (dict): Dictionary containing batch data with 'img' tensor.
|
|
@@ -122,7 +115,7 @@ class DetectionTrainer(BaseTrainer):
|
|
|
122
115
|
"""
|
|
123
116
|
for k, v in batch.items():
|
|
124
117
|
if isinstance(v, torch.Tensor):
|
|
125
|
-
batch[k] = v.to(self.device, non_blocking=True)
|
|
118
|
+
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
|
|
126
119
|
batch["img"] = batch["img"].float() / 255
|
|
127
120
|
if self.args.multi_scale:
|
|
128
121
|
imgs = batch["img"]
|
|
@@ -138,10 +131,6 @@ class DetectionTrainer(BaseTrainer):
|
|
|
138
131
|
] # new shape (stretched to gs-multiple)
|
|
139
132
|
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
|
140
133
|
batch["img"] = imgs
|
|
141
|
-
|
|
142
|
-
if self.args.compile:
|
|
143
|
-
for k in self.dynamic_tensors:
|
|
144
|
-
torch._dynamo.maybe_mark_dynamic(batch[k], 0)
|
|
145
134
|
return batch
|
|
146
135
|
|
|
147
136
|
def set_model_attributes(self):
|
|
@@ -156,8 +145,7 @@ class DetectionTrainer(BaseTrainer):
|
|
|
156
145
|
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
|
157
146
|
|
|
158
147
|
def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
|
|
159
|
-
"""
|
|
160
|
-
Return a YOLO detection model.
|
|
148
|
+
"""Return a YOLO detection model.
|
|
161
149
|
|
|
162
150
|
Args:
|
|
163
151
|
cfg (str, optional): Path to model configuration file.
|
|
@@ -180,8 +168,7 @@ class DetectionTrainer(BaseTrainer):
|
|
|
180
168
|
)
|
|
181
169
|
|
|
182
170
|
def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
|
|
183
|
-
"""
|
|
184
|
-
Return a loss dict with labeled training loss items tensor.
|
|
171
|
+
"""Return a loss dict with labeled training loss items tensor.
|
|
185
172
|
|
|
186
173
|
Args:
|
|
187
174
|
loss_items (list[float], optional): List of loss values.
|
|
@@ -208,8 +195,7 @@ class DetectionTrainer(BaseTrainer):
|
|
|
208
195
|
)
|
|
209
196
|
|
|
210
197
|
def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
|
|
211
|
-
"""
|
|
212
|
-
Plot training samples with their annotations.
|
|
198
|
+
"""Plot training samples with their annotations.
|
|
213
199
|
|
|
214
200
|
Args:
|
|
215
201
|
batch (dict[str, Any]): Dictionary containing batch data.
|
|
@@ -222,10 +208,6 @@ class DetectionTrainer(BaseTrainer):
|
|
|
222
208
|
on_plot=self.on_plot,
|
|
223
209
|
)
|
|
224
210
|
|
|
225
|
-
def plot_metrics(self):
|
|
226
|
-
"""Plot metrics from a CSV file."""
|
|
227
|
-
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
|
|
228
|
-
|
|
229
211
|
def plot_training_labels(self):
|
|
230
212
|
"""Create a labeled training plot of the YOLO model."""
|
|
231
213
|
boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
|
|
@@ -233,8 +215,7 @@ class DetectionTrainer(BaseTrainer):
|
|
|
233
215
|
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
|
|
234
216
|
|
|
235
217
|
def auto_batch(self):
|
|
236
|
-
"""
|
|
237
|
-
Get optimal batch size by calculating memory occupation of model.
|
|
218
|
+
"""Get optimal batch size by calculating memory occupation of model.
|
|
238
219
|
|
|
239
220
|
Returns:
|
|
240
221
|
(int): Optimal batch size.
|