dgenerate_ultralytics_headless-8.3.143-py3-none-any.whl → dgenerate_ultralytics_headless-8.3.145-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/METADATA +2 -2
- dgenerate_ultralytics_headless-8.3.145.dist-info/RECORD +272 -0
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +52 -51
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +191 -161
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +4 -6
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +2 -2
- ultralytics/solutions/instance_segmentation.py +7 -4
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -11
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +189 -79
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +45 -29
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/top_level.txt +0 -0
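
The per-file `+N -M` counts above can be verified locally. Below is a minimal sketch, not the diff site's implementation, that reproduces a file-level wheel diff with only the Python standard library; it assumes both wheel files have already been downloaded into the current directory (e.g. via `pip download dgenerate-ultralytics-headless==8.3.143 --no-deps`).

```python
import difflib
import zipfile

OLD = "dgenerate_ultralytics_headless-8.3.143-py3-none-any.whl"
NEW = "dgenerate_ultralytics_headless-8.3.145-py3-none-any.whl"

with zipfile.ZipFile(OLD) as old_whl, zipfile.ZipFile(NEW) as new_whl:
    # Compare only Python files present in both wheels (a wheel is a zip archive)
    for name in sorted(set(old_whl.namelist()) & set(new_whl.namelist())):
        if not name.endswith(".py"):
            continue
        old_lines = old_whl.read(name).decode("utf-8").splitlines(keepends=True)
        new_lines = new_whl.read(name).decode("utf-8").splitlines(keepends=True)
        diff = list(difflib.unified_diff(old_lines, new_lines, fromfile=name, tofile=name))
        added = sum(1 for line in diff[2:] if line.startswith("+"))  # skip ---/+++ headers
        removed = sum(1 for line in diff[2:] if line.startswith("-"))
        if added or removed:
            print(f"{name} +{added} -{removed}")  # same +N -M summary as the list above
```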
ultralytics/models/yolo/classify/predict.py

```diff
@@ -44,7 +44,7 @@ class ClassificationPredictor(BasePredictor):
         tasks. It ensures the task is set to 'classify' regardless of input configuration.
 
         Args:
-            cfg (dict): Default configuration dictionary containing prediction settings.
+            cfg (dict): Default configuration dictionary containing prediction settings.
             overrides (dict, optional): Configuration overrides that take precedence over cfg.
             _callbacks (list, optional): List of callback functions to be executed during prediction.
         """
@@ -53,7 +53,7 @@ class ClassificationPredictor(BasePredictor):
         self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
 
     def setup_source(self, source):
-        """
+        """Set up source and inference mode and classify transforms."""
         super().setup_source(source)
         updated = (
             self.model.model.transforms.transforms[0].size != max(self.imgsz)
@@ -68,14 +68,14 @@ class ClassificationPredictor(BasePredictor):
         is_legacy_transform = any(
             self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
         )
-        if is_legacy_transform:  #
+        if is_legacy_transform:  # Handle legacy transforms
             img = torch.stack([self.transforms(im) for im in img], dim=0)
         else:
             img = torch.stack(
                 [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
             )
         img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
-        return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
+        return img.half() if self.model.fp16 else img.float()  # Convert uint8 to fp16/32
 
     def postprocess(self, preds, img, orig_imgs):
         """
@@ -89,7 +89,7 @@ class ClassificationPredictor(BasePredictor):
         Returns:
             (List[Results]): List of Results objects containing classification results for each image.
         """
-        if not isinstance(orig_imgs, list):  #
+        if not isinstance(orig_imgs, list):  # Input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
 
         preds = preds[0] if isinstance(preds, (list, tuple)) else preds
```
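
The hunks above only touch docstrings and comments in `ClassificationPredictor`'s preprocessing and postprocessing. As a usage sketch (not part of the diff), classification inference through the public `YOLO` API exercises exactly these code paths; `"image.jpg"` is a placeholder path.

```python
from ultralytics import YOLO

model = YOLO("yolo11n-cls.pt")  # classification checkpoint, downloaded on first use
results = model("image.jpg")  # preprocess() handles BGR->RGB and uint8->fp16/32
probs = results[0].probs  # classification probabilities from postprocess()
print(probs.top1, probs.top1conf)  # best class index and its confidence
```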
ultralytics/models/yolo/classify/train.py

```diff
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from copy import copy
+from typing import Any, Dict, Optional
 
 import torch
 
@@ -15,14 +16,14 @@ from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
 
 class ClassificationTrainer(BaseTrainer):
     """
-    A class extending
+    A trainer class extending BaseTrainer for training image classification models.
 
     This trainer handles the training process for image classification tasks, supporting both YOLO classification models
-    and torchvision models.
+    and torchvision models with comprehensive dataset handling and validation.
 
     Attributes:
         model (ClassificationModel): The classification model to be trained.
-        data (
+        data (Dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
         loss_names (List[str]): Names of the loss functions used during training.
         validator (ClassificationValidator): Validator instance for model evaluation.
 
@@ -41,13 +42,14 @@ class ClassificationTrainer(BaseTrainer):
         plot_training_samples: Plot training samples with their annotations.
 
     Examples:
+        Initialize and train a classification model
         >>> from ultralytics.models.yolo.classify import ClassificationTrainer
         >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
        >>> trainer = ClassificationTrainer(overrides=args)
         >>> trainer.train()
     """
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+    def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):
         """
         Initialize a ClassificationTrainer object.
 
@@ -55,11 +57,12 @@ class ClassificationTrainer(BaseTrainer):
         image size if not specified.
 
         Args:
-            cfg (
-            overrides (
-            _callbacks (
+            cfg (Dict[str, Any], optional): Default configuration dictionary containing training parameters.
+            overrides (Dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
+            _callbacks (List[Any], optional): List of callback functions to be executed during training.
 
         Examples:
+            Create a trainer with custom configuration
             >>> from ultralytics.models.yolo.classify import ClassificationTrainer
             >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
             >>> trainer = ClassificationTrainer(overrides=args)
@@ -76,14 +79,14 @@ class ClassificationTrainer(BaseTrainer):
         """Set the YOLO model's class names from the loaded dataset."""
         self.model.names = self.data["names"]
 
-    def get_model(self, cfg=None, weights=None, verbose=True):
+    def get_model(self, cfg=None, weights=None, verbose: bool = True):
         """
-        Return a modified PyTorch model configured for training YOLO.
+        Return a modified PyTorch model configured for training YOLO classification.
 
         Args:
-            cfg (Any): Model configuration.
-            weights (Any): Pre-trained model weights.
-            verbose (bool): Whether to display model information.
+            cfg (Any, optional): Model configuration.
+            weights (Any, optional): Pre-trained model weights.
+            verbose (bool, optional): Whether to display model information.
 
         Returns:
             (ClassificationModel): Configured PyTorch model for classification.
@@ -120,29 +123,29 @@ class ClassificationTrainer(BaseTrainer):
             ClassificationModel.reshape_outputs(self.model, self.data["nc"])
         return ckpt
 
-    def build_dataset(self, img_path, mode="train", batch=None):
+    def build_dataset(self, img_path: str, mode: str = "train", batch=None):
         """
         Create a ClassificationDataset instance given an image path and mode.
 
         Args:
             img_path (str): Path to the dataset images.
-            mode (str): Dataset mode ('train', 'val', or 'test').
-            batch (Any): Batch information (unused in this implementation).
+            mode (str, optional): Dataset mode ('train', 'val', or 'test').
+            batch (Any, optional): Batch information (unused in this implementation).
 
         Returns:
             (ClassificationDataset): Dataset for the specified mode.
         """
         return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
 
-    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+    def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
         """
         Return PyTorch DataLoader with transforms to preprocess images.
 
         Args:
             dataset_path (str): Path to the dataset.
-            batch_size (int): Number of images per batch.
-            rank (int): Process rank for distributed training.
-            mode (str): 'train', 'val', or 'test' mode.
+            batch_size (int, optional): Number of images per batch.
+            rank (int, optional): Process rank for distributed training.
+            mode (str, optional): 'train', 'val', or 'test' mode.
 
         Returns:
             (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
@@ -159,14 +162,14 @@ class ClassificationTrainer(BaseTrainer):
         self.model.transforms = loader.dataset.torch_transforms
         return loader
 
-    def preprocess_batch(self, batch):
-        """
+    def preprocess_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """Preprocess a batch of images and classes."""
         batch["img"] = batch["img"].to(self.device)
         batch["cls"] = batch["cls"].to(self.device)
         return batch
 
-    def progress_string(self):
-        """
+    def progress_string(self) -> str:
+        """Return a formatted string showing training progress."""
         return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
             "Epoch",
             "GPU_mem",
@@ -176,22 +179,23 @@ class ClassificationTrainer(BaseTrainer):
         )
 
     def get_validator(self):
-        """
+        """Return an instance of ClassificationValidator for validation."""
         self.loss_names = ["loss"]
         return yolo.classify.ClassificationValidator(
             self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
-    def label_loss_items(self, loss_items=None, prefix="train"):
+    def label_loss_items(self, loss_items: Optional[torch.Tensor] = None, prefix: str = "train"):
         """
         Return a loss dict with labelled training loss items tensor.
 
         Args:
             loss_items (torch.Tensor, optional): Loss tensor items.
-            prefix (str): Prefix to prepend to loss names.
+            prefix (str, optional): Prefix to prepend to loss names.
 
         Returns:
-            (
+            keys (List[str]): List of loss keys if loss_items is None.
+            loss_dict (Dict[str, float]): Dictionary of loss items if loss_items is provided.
         """
         keys = [f"{prefix}/{x}" for x in self.loss_names]
         if loss_items is None:
@@ -216,7 +220,7 @@ class ClassificationTrainer(BaseTrainer):
         self.metrics.pop("fitness", None)
         self.run_callbacks("on_fit_epoch_end")
 
-    def plot_training_samples(self, batch, ni):
+    def plot_training_samples(self, batch: Dict[str, torch.Tensor], ni: int):
         """
         Plot training samples with their annotations.
 
```
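
The `label_loss_items` hunk above documents a dual return: a list of keys when `loss_items` is None, a labelled dict otherwise. A standalone sketch of that contract (a hypothetical free function mirroring the trainer method, not the method itself):

```python
from typing import Dict, List, Optional, Union

def label_loss_items(
    loss_names: List[str], loss_items: Optional[List[float]] = None, prefix: str = "train"
) -> Union[List[str], Dict[str, float]]:
    keys = [f"{prefix}/{x}" for x in loss_names]
    if loss_items is None:
        return keys  # e.g. ["train/loss"] for the classification trainer
    return dict(zip(keys, (round(float(x), 5) for x in loss_items)))

print(label_loss_items(["loss"]))            # ['train/loss']
print(label_loss_items(["loss"], [0.4321]))  # {'train/loss': 0.4321}
```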
ultralytics/models/yolo/classify/val.py

```diff
@@ -52,9 +52,6 @@ class ClassificationValidator(BaseValidator):
         """
         Initialize ClassificationValidator with dataloader, save directory, and other parameters.
 
-        This validator handles the validation process for classification models, including metrics calculation,
-        confusion matrix generation, and visualization of results.
-
         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
             save_dir (str | Path, optional): Directory to save results.
@@ -101,8 +98,9 @@ class ClassificationValidator(BaseValidator):
             preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
             batch (dict): Batch data containing images and class labels.
 
-
-
+        Notes:
+            This method appends the top-N predictions (sorted by confidence in descending order) to the
+            prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
         """
         n5 = min(len(self.names), 5)
         self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
@@ -112,13 +110,14 @@ class ClassificationValidator(BaseValidator):
         """
         Finalize metrics including confusion matrix and processing speed.
 
-        This method processes the accumulated predictions and targets to generate the confusion matrix,
-        optionally plots it, and updates the metrics object with speed information.
-
         Args:
             *args (Any): Variable length argument list.
             **kwargs (Any): Arbitrary keyword arguments.
 
+        Notes:
+            This method processes the accumulated predictions and targets to generate the confusion matrix,
+            optionally plots it, and updates the metrics object with speed information.
+
         Examples:
             >>> validator = ClassificationValidator()
             >>> validator.pred = [torch.tensor([[0, 1, 2]])]  # Top-3 predictions for one sample
```
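
The new Notes section describes the top-N accumulation performed by `update_metrics`. A minimal sketch of that one line with dummy logits in place of real model output:

```python
import torch

names = {0: "cat", 1: "dog", 2: "bird"}
preds = torch.tensor([[0.1, 2.0, 0.5]])  # logits for one image, 3 classes
n5 = min(len(names), 5)  # top-N is capped at 5
top_n = preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu()
print(top_n)  # tensor([[1, 2, 0]], dtype=torch.int32) -> 'dog' ranked first
```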
ultralytics/models/yolo/detect/predict.py

```diff
@@ -21,6 +21,7 @@ class DetectionPredictor(BasePredictor):
         postprocess: Process raw model predictions into detection results.
         construct_results: Build Results objects from processed predictions.
         construct_result: Create a single Result object from a prediction.
+        get_obj_feats: Extract object features from the feature maps.
 
     Examples:
         >>> from ultralytics.utils import ASSETS
```
ultralytics/models/yolo/detect/train.py

```diff
@@ -3,6 +3,7 @@
 import math
 import random
 from copy import copy
+from typing import Dict, List, Optional
 
 import numpy as np
 import torch.nn as nn
@@ -21,12 +22,12 @@ class DetectionTrainer(BaseTrainer):
     A class extending the BaseTrainer class for training based on a detection model.
 
     This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
-    for object detection.
+    for object detection including dataset building, data loading, preprocessing, and model configuration.
 
     Attributes:
         model (DetectionModel): The YOLO detection model being trained.
-        data (
-        loss_names (
+        data (Dict): Dictionary containing dataset information including class names and number of classes.
+        loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
 
     Methods:
         build_dataset: Build YOLO dataset for training or validation.
@@ -49,14 +50,14 @@ class DetectionTrainer(BaseTrainer):
         >>> trainer.train()
     """
 
-    def build_dataset(self, img_path, mode="train", batch=None):
+    def build_dataset(self, img_path: str, mode: str = "train", batch: Optional[int] = None):
         """
         Build YOLO Dataset for training or validation.
 
         Args:
             img_path (str): Path to the folder containing images.
-            mode (str):
-            batch (int, optional): Size of batches, this is for
+            mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for 'rect' mode.
 
         Returns:
             (Dataset): YOLO dataset object configured for the specified mode.
@@ -64,7 +65,7 @@ class DetectionTrainer(BaseTrainer):
         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
 
-    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+    def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
         """
         Construct and return dataloader for the specified mode.
 
@@ -87,15 +88,15 @@ class DetectionTrainer(BaseTrainer):
         workers = self.args.workers if mode == "train" else self.args.workers * 2
         return build_dataloader(dataset, batch_size, workers, shuffle, rank)  # return dataloader
 
-    def preprocess_batch(self, batch):
+    def preprocess_batch(self, batch: Dict) -> Dict:
         """
         Preprocess a batch of images by scaling and converting to float.
 
         Args:
-            batch (
+            batch (Dict): Dictionary containing batch data with 'img' tensor.
 
         Returns:
-            (
+            (Dict): Preprocessed batch with normalized images.
         """
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
         if self.args.multi_scale:
@@ -125,7 +126,7 @@ class DetectionTrainer(BaseTrainer):
         self.model.args = self.args  # attach hyperparameters to model
         # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
 
-    def get_model(self, cfg=None, weights=None, verbose=True):
+    def get_model(self, cfg: Optional[str] = None, weights: Optional[str] = None, verbose: bool = True):
         """
         Return a YOLO detection model.
 
@@ -149,7 +150,7 @@ class DetectionTrainer(BaseTrainer):
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
-    def label_loss_items(self, loss_items=None, prefix="train"):
+    def label_loss_items(self, loss_items: Optional[List[float]] = None, prefix: str = "train"):
         """
         Return a loss dict with labeled training loss items tensor.
 
@@ -177,12 +178,12 @@ class DetectionTrainer(BaseTrainer):
             "Size",
         )
 
-    def plot_training_samples(self, batch, ni):
+    def plot_training_samples(self, batch: Dict, ni: int):
         """
         Plot training samples with their annotations.
 
         Args:
-            batch (
+            batch (Dict): Dictionary containing batch data.
             ni (int): Number of iterations.
         """
         plot_images(
```
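
The `preprocess_batch` hunk above adds `Dict` annotations; its core is moving the batch to the device and rescaling uint8 pixels to [0, 1]. A self-contained sketch with a dummy batch (device chosen defensively):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch = {"img": torch.randint(0, 256, (2, 3, 640, 640), dtype=torch.uint8)}  # fake images
batch["img"] = batch["img"].to(device, non_blocking=True).float() / 255  # uint8 -> [0, 1] float
print(batch["img"].dtype, batch["img"].min().item() >= 0, batch["img"].max().item() <= 1)
```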
ultralytics/models/yolo/detect/val.py

```diff
@@ -2,6 +2,7 @@
 
 import os
 from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -26,13 +27,13 @@ class DetectionValidator(BaseValidator):
         nt_per_image (np.ndarray): Number of targets per image.
         is_coco (bool): Whether the dataset is COCO.
         is_lvis (bool): Whether the dataset is LVIS.
-        class_map (
+        class_map (List[int]): Mapping from model class indices to dataset class indices.
         metrics (DetMetrics): Object detection metrics calculator.
         iouv (torch.Tensor): IoU thresholds for mAP calculation.
         niou (int): Number of IoU thresholds.
-        lb (
-        jdict (
-        stats (
+        lb (List[Any]): List for storing ground truth labels for hybrid saving.
+        jdict (List[Dict[str, Any]]): List for storing JSON detection results.
+        stats (Dict[str, List[torch.Tensor]]): Dictionary for storing statistics during validation.
 
     Examples:
         >>> from ultralytics.models.yolo.detect import DetectionValidator
@@ -49,8 +50,8 @@ class DetectionValidator(BaseValidator):
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
             save_dir (Path, optional): Directory to save results.
             pbar (Any, optional): Progress bar for displaying progress.
-            args (
-            _callbacks (
+            args (Dict[str, Any], optional): Arguments for the validator.
+            _callbacks (List[Any], optional): List of callback functions.
         """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.nt_per_class = None
@@ -63,15 +64,15 @@ class DetectionValidator(BaseValidator):
         self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
         self.niou = self.iouv.numel()
 
-    def preprocess(self, batch):
+    def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
         Preprocess batch of images for YOLO validation.
 
         Args:
-            batch (
+            batch (Dict[str, Any]): Batch containing images and annotations.
 
         Returns:
-            (
+            (Dict[str, Any]): Preprocessed batch.
         """
         batch["img"] = batch["img"].to(self.device, non_blocking=True)
         batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
@@ -80,7 +81,7 @@ class DetectionValidator(BaseValidator):
 
         return batch
 
-    def init_metrics(self, model):
+    def init_metrics(self, model: torch.nn.Module) -> None:
         """
         Initialize evaluation metrics for YOLO detection validation.
 
@@ -106,11 +107,11 @@ class DetectionValidator(BaseValidator):
         self.jdict = []
         self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
 
-    def get_desc(self):
+    def get_desc(self) -> str:
         """Return a formatted string summarizing class metrics of YOLO model."""
         return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
 
-    def postprocess(self, preds):
+    def postprocess(self, preds: torch.Tensor) -> List[torch.Tensor]:
         """
         Apply Non-maximum suppression to prediction outputs.
 
@@ -132,16 +133,16 @@ class DetectionValidator(BaseValidator):
             rotated=self.args.task == "obb",
         )
 
-    def _prepare_batch(self, si, batch):
+    def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
         Prepare a batch of images and annotations for validation.
 
         Args:
             si (int): Batch index.
-            batch (
+            batch (Dict[str, Any]): Batch data containing images and annotations.
 
         Returns:
-            (
+            (Dict[str, Any]): Prepared batch with processed annotations.
         """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
@@ -154,13 +155,13 @@ class DetectionValidator(BaseValidator):
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
-    def _prepare_pred(self, pred, pbatch):
+    def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> torch.Tensor:
         """
         Prepare predictions for evaluation against ground truth.
 
         Args:
             pred (torch.Tensor): Model predictions.
-            pbatch (
+            pbatch (Dict[str, Any]): Prepared batch information.
 
         Returns:
             (torch.Tensor): Prepared predictions in native space.
@@ -171,13 +172,13 @@ class DetectionValidator(BaseValidator):
         )  # native-space pred
         return predn
 
-    def update_metrics(self, preds, batch):
+    def update_metrics(self, preds: List[torch.Tensor], batch: Dict[str, Any]) -> None:
         """
         Update metrics with new predictions and ground truth.
 
         Args:
             preds (List[torch.Tensor]): List of predictions from the model.
-            batch (
+            batch (Dict[str, Any]): Batch data containing ground truth.
         """
         for si, pred in enumerate(preds):
             self.seen += 1
@@ -226,7 +227,7 @@ class DetectionValidator(BaseValidator):
                 self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
             )
 
-    def finalize_metrics(self, *args, **kwargs):
+    def finalize_metrics(self, *args: Any, **kwargs: Any) -> None:
         """
         Set final values for metrics speed and confusion matrix.
 
@@ -237,12 +238,12 @@ class DetectionValidator(BaseValidator):
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
 
-    def get_stats(self):
+    def get_stats(self) -> Dict[str, Any]:
         """
         Calculate and return metrics statistics.
 
         Returns:
-            (
+            (Dict[str, Any]): Dictionary containing metrics results.
         """
         stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
         self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
@@ -252,7 +253,7 @@ class DetectionValidator(BaseValidator):
         self.metrics.process(**stats, on_plot=self.on_plot)
         return self.metrics.results_dict
 
-    def print_results(self):
+    def print_results(self) -> None:
         """Print training/validation set metrics per class."""
         pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
         LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
@@ -272,7 +273,7 @@ class DetectionValidator(BaseValidator):
             save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
         )
 
-    def _process_batch(self, detections, gt_bboxes, gt_cls):
+    def _process_batch(self, detections: torch.Tensor, gt_bboxes: torch.Tensor, gt_cls: torch.Tensor) -> torch.Tensor:
         """
         Return correct prediction matrix.
 
@@ -289,7 +290,7 @@ class DetectionValidator(BaseValidator):
         iou = box_iou(gt_bboxes, detections[:, :4])
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
-    def build_dataset(self, img_path, mode="val", batch=None):
+    def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None):
         """
         Build YOLO Dataset.
 
@@ -303,7 +304,7 @@ class DetectionValidator(BaseValidator):
         """
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
 
-    def get_dataloader(self, dataset_path, batch_size):
+    def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
         """
         Construct and return dataloader.
 
@@ -317,12 +318,12 @@ class DetectionValidator(BaseValidator):
         dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
         return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader
 
-    def plot_val_samples(self, batch, ni):
+    def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
         """
         Plot validation image samples.
 
         Args:
-            batch (
+            batch (Dict[str, Any]): Batch containing images and annotations.
             ni (int): Batch index.
         """
         plot_images(
@@ -336,12 +337,12 @@ class DetectionValidator(BaseValidator):
             on_plot=self.on_plot,
         )
 
-    def plot_predictions(self, batch, preds, ni):
+    def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
         """
         Plot predicted bounding boxes on input images and save the result.
 
         Args:
-            batch (
+            batch (Dict[str, Any]): Batch containing images and annotations.
             preds (List[torch.Tensor]): List of predictions from the model.
             ni (int): Batch index.
         """
@@ -354,14 +355,14 @@ class DetectionValidator(BaseValidator):
             on_plot=self.on_plot,
         )  # pred
 
-    def save_one_txt(self, predn, save_conf, shape, file):
+    def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path) -> None:
         """
         Save YOLO detections to a txt file in normalized coordinates in a specific format.
 
         Args:
             predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
             save_conf (bool): Whether to save confidence scores.
-            shape (
+            shape (Tuple[int, int]): Shape of the original image.
             file (Path): File path to save the detections.
         """
         from ultralytics.engine.results import Results
@@ -373,7 +374,7 @@ class DetectionValidator(BaseValidator):
             boxes=predn[:, :6],
         ).save_txt(file, save_conf=save_conf)
 
-    def pred_to_json(self, predn, filename):
+    def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:
         """
         Serialize YOLO predictions to COCO json format.
 
@@ -395,15 +396,15 @@ class DetectionValidator(BaseValidator):
                 }
             )
 
-    def eval_json(self, stats):
+    def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
         """
         Evaluate YOLO output in JSON format and return performance statistics.
 
         Args:
-            stats (
+            stats (Dict[str, Any]): Current statistics dictionary.
 
         Returns:
-            (
+            (Dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.
         """
         if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
             pred_json = self.save_dir / "predictions.json"  # predictions
```
|