ultralytics 8.3.89__py3-none-any.whl → 8.3.90__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +118 -30
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +5 -5
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +13 -19
- ultralytics/engine/exporter.py +19 -17
- ultralytics/engine/model.py +67 -88
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +21 -18
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +12 -13
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +20 -11
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +22 -11
- ultralytics/models/nas/predict.py +9 -4
- ultralytics/models/nas/val.py +5 -5
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +18 -15
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +42 -6
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +24 -3
- ultralytics/models/yolo/classify/train.py +77 -10
- ultralytics/models/yolo/classify/val.py +40 -15
- ultralytics/models/yolo/detect/predict.py +23 -10
- ultralytics/models/yolo/detect/train.py +85 -15
- ultralytics/models/yolo/detect/val.py +145 -21
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +12 -4
- ultralytics/models/yolo/obb/train.py +7 -0
- ultralytics/models/yolo/obb/val.py +25 -7
- ultralytics/models/yolo/pose/predict.py +22 -6
- ultralytics/models/yolo/pose/train.py +17 -1
- ultralytics/models/yolo/pose/val.py +46 -21
- ultralytics/models/yolo/segment/predict.py +22 -8
- ultralytics/models/yolo/segment/train.py +6 -0
- ultralytics/models/yolo/segment/val.py +100 -14
- ultralytics/models/yolo/world/train.py +38 -8
- ultralytics/models/yolo/world/train_world.py +39 -10
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +3 -0
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +221 -69
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +32 -27
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +35 -22
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +116 -35
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +13 -9
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +112 -45
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +61 -53
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +64 -45
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +181 -33
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +8 -16
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/METADATA +1 -1
- ultralytics-8.3.90.dist-info/RECORD +250 -0
- ultralytics-8.3.89.dist-info/RECORD +0 -250
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
@@ -13,18 +13,43 @@ class ClassificationValidator(BaseValidator):
|
|
13
13
|
"""
|
14
14
|
A class extending the BaseValidator class for validation based on a classification model.
|
15
15
|
|
16
|
-
|
17
|
-
|
16
|
+
This validator handles the validation process for classification models, including metrics calculation,
|
17
|
+
confusion matrix generation, and visualization of results.
|
18
|
+
|
19
|
+
Attributes:
|
20
|
+
targets (List[torch.Tensor]): Ground truth class labels.
|
21
|
+
pred (List[torch.Tensor]): Model predictions.
|
22
|
+
metrics (ClassifyMetrics): Object to calculate and store classification metrics.
|
23
|
+
names (Dict): Mapping of class indices to class names.
|
24
|
+
nc (int): Number of classes.
|
25
|
+
confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.
|
26
|
+
|
27
|
+
Methods:
|
28
|
+
get_desc: Return a formatted string summarizing classification metrics.
|
29
|
+
init_metrics: Initialize confusion matrix, class names, and tracking containers.
|
30
|
+
preprocess: Preprocess input batch by moving data to device.
|
31
|
+
update_metrics: Update running metrics with model predictions and batch targets.
|
32
|
+
finalize_metrics: Finalize metrics including confusion matrix and processing speed.
|
33
|
+
postprocess: Extract the primary prediction from model output.
|
34
|
+
get_stats: Calculate and return a dictionary of metrics.
|
35
|
+
build_dataset: Create a ClassificationDataset instance for validation.
|
36
|
+
get_dataloader: Build and return a data loader for classification validation.
|
37
|
+
print_results: Print evaluation metrics for the classification model.
|
38
|
+
plot_val_samples: Plot validation image samples with their ground truth labels.
|
39
|
+
plot_predictions: Plot images with their predicted class labels.
|
18
40
|
|
19
41
|
Examples:
|
20
42
|
>>> from ultralytics.models.yolo.classify import ClassificationValidator
|
21
43
|
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
|
22
44
|
>>> validator = ClassificationValidator(args=args)
|
23
45
|
>>> validator()
|
46
|
+
|
47
|
+
Notes:
|
48
|
+
Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
|
24
49
|
"""
|
25
50
|
|
26
51
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
27
|
-
"""
|
52
|
+
"""Initialize ClassificationValidator with dataloader, save directory, and other parameters."""
|
28
53
|
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
29
54
|
self.targets = None
|
30
55
|
self.pred = None
|
@@ -32,11 +57,11 @@ class ClassificationValidator(BaseValidator):
|
|
32
57
|
self.metrics = ClassifyMetrics()
|
33
58
|
|
34
59
|
def get_desc(self):
|
35
|
-
"""
|
60
|
+
"""Return a formatted string summarizing classification metrics."""
|
36
61
|
return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
|
37
62
|
|
38
63
|
def init_metrics(self, model):
|
39
|
-
"""Initialize confusion matrix, class names, and
|
64
|
+
"""Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
|
40
65
|
self.names = model.names
|
41
66
|
self.nc = len(model.names)
|
42
67
|
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
|
@@ -44,20 +69,20 @@ class ClassificationValidator(BaseValidator):
|
|
44
69
|
self.targets = []
|
45
70
|
|
46
71
|
def preprocess(self, batch):
|
47
|
-
"""
|
72
|
+
"""Preprocess input batch by moving data to device and converting to appropriate dtype."""
|
48
73
|
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
49
74
|
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
|
50
75
|
batch["cls"] = batch["cls"].to(self.device)
|
51
76
|
return batch
|
52
77
|
|
53
78
|
def update_metrics(self, preds, batch):
|
54
|
-
"""
|
79
|
+
"""Update running metrics with model predictions and batch targets."""
|
55
80
|
n5 = min(len(self.names), 5)
|
56
81
|
self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
|
57
82
|
self.targets.append(batch["cls"].type(torch.int32).cpu())
|
58
83
|
|
59
84
|
def finalize_metrics(self, *args, **kwargs):
|
60
|
-
"""
|
85
|
+
"""Finalize metrics including confusion matrix and processing speed."""
|
61
86
|
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
62
87
|
if self.args.plots:
|
63
88
|
for normalize in True, False:
|
@@ -69,30 +94,30 @@ class ClassificationValidator(BaseValidator):
|
|
69
94
|
self.metrics.save_dir = self.save_dir
|
70
95
|
|
71
96
|
def postprocess(self, preds):
|
72
|
-
"""
|
97
|
+
"""Extract the primary prediction from model output if it's in a list or tuple format."""
|
73
98
|
return preds[0] if isinstance(preds, (list, tuple)) else preds
|
74
99
|
|
75
100
|
def get_stats(self):
|
76
|
-
"""
|
101
|
+
"""Calculate and return a dictionary of metrics by processing targets and predictions."""
|
77
102
|
self.metrics.process(self.targets, self.pred)
|
78
103
|
return self.metrics.results_dict
|
79
104
|
|
80
105
|
def build_dataset(self, img_path):
|
81
|
-
"""
|
106
|
+
"""Create a ClassificationDataset instance for validation."""
|
82
107
|
return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
|
83
108
|
|
84
109
|
def get_dataloader(self, dataset_path, batch_size):
|
85
|
-
"""
|
110
|
+
"""Build and return a data loader for classification validation."""
|
86
111
|
dataset = self.build_dataset(dataset_path)
|
87
112
|
return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
|
88
113
|
|
89
114
|
def print_results(self):
|
90
|
-
"""
|
115
|
+
"""Print evaluation metrics for the classification model."""
|
91
116
|
pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format
|
92
117
|
LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
|
93
118
|
|
94
119
|
def plot_val_samples(self, batch, ni):
|
95
|
-
"""Plot validation image samples."""
|
120
|
+
"""Plot validation image samples with their ground truth labels."""
|
96
121
|
plot_images(
|
97
122
|
images=batch["img"],
|
98
123
|
batch_idx=torch.arange(len(batch["img"])),
|
@@ -103,7 +128,7 @@ class ClassificationValidator(BaseValidator):
|
|
103
128
|
)
|
104
129
|
|
105
130
|
def plot_predictions(self, batch, preds, ni):
|
106
|
-
"""
|
131
|
+
"""Plot images with their predicted class labels and save the visualization."""
|
107
132
|
plot_images(
|
108
133
|
batch["img"],
|
109
134
|
batch_idx=torch.arange(len(batch["img"])),
|
@@ -9,6 +9,19 @@ class DetectionPredictor(BasePredictor):
|
|
9
9
|
"""
|
10
10
|
A class extending the BasePredictor class for prediction based on a detection model.
|
11
11
|
|
12
|
+
This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
|
13
|
+
with bounding boxes and class predictions.
|
14
|
+
|
15
|
+
Attributes:
|
16
|
+
args (namespace): Configuration arguments for the predictor.
|
17
|
+
model (nn.Module): The detection model used for inference.
|
18
|
+
batch (List): Batch of images and metadata for processing.
|
19
|
+
|
20
|
+
Methods:
|
21
|
+
postprocess: Process raw model predictions into detection results.
|
22
|
+
construct_results: Build Results objects from processed predictions.
|
23
|
+
construct_result: Create a single Result object from a prediction.
|
24
|
+
|
12
25
|
Examples:
|
13
26
|
>>> from ultralytics.utils import ASSETS
|
14
27
|
>>> from ultralytics.models.yolo.detect import DetectionPredictor
|
@@ -38,15 +51,15 @@ class DetectionPredictor(BasePredictor):
|
|
38
51
|
|
39
52
|
def construct_results(self, preds, img, orig_imgs):
|
40
53
|
"""
|
41
|
-
|
54
|
+
Construct a list of Results objects from model predictions.
|
42
55
|
|
43
56
|
Args:
|
44
|
-
preds (List[torch.Tensor]): List of predicted bounding boxes and scores.
|
45
|
-
img (torch.Tensor):
|
57
|
+
preds (List[torch.Tensor]): List of predicted bounding boxes and scores for each image.
|
58
|
+
img (torch.Tensor): Batch of preprocessed images used for inference.
|
46
59
|
orig_imgs (List[np.ndarray]): List of original images before preprocessing.
|
47
60
|
|
48
61
|
Returns:
|
49
|
-
(
|
62
|
+
(List[Results]): List of Results objects containing detection information for each image.
|
50
63
|
"""
|
51
64
|
return [
|
52
65
|
self.construct_result(pred, img, orig_img, img_path)
|
@@ -55,16 +68,16 @@ class DetectionPredictor(BasePredictor):
|
|
55
68
|
|
56
69
|
def construct_result(self, pred, img, orig_img, img_path):
|
57
70
|
"""
|
58
|
-
|
71
|
+
Construct a single Results object from one image prediction.
|
59
72
|
|
60
73
|
Args:
|
61
|
-
pred (torch.Tensor):
|
62
|
-
img (torch.Tensor):
|
63
|
-
orig_img (np.ndarray):
|
64
|
-
img_path (str):
|
74
|
+
pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
|
75
|
+
img (torch.Tensor): Preprocessed image tensor used for inference.
|
76
|
+
orig_img (np.ndarray): Original image before preprocessing.
|
77
|
+
img_path (str): Path to the original image file.
|
65
78
|
|
66
79
|
Returns:
|
67
|
-
(Results):
|
80
|
+
(Results): Results object containing the original image, image path, class names, and scaled bounding boxes.
|
68
81
|
"""
|
69
82
|
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
70
83
|
return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
|
@@ -20,6 +20,28 @@ class DetectionTrainer(BaseTrainer):
|
|
20
20
|
"""
|
21
21
|
A class extending the BaseTrainer class for training based on a detection model.
|
22
22
|
|
23
|
+
This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
|
24
|
+
for object detection.
|
25
|
+
|
26
|
+
Attributes:
|
27
|
+
model (DetectionModel): The YOLO detection model being trained.
|
28
|
+
data (Dict): Dictionary containing dataset information including class names and number of classes.
|
29
|
+
loss_names (Tuple[str]): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
|
30
|
+
|
31
|
+
Methods:
|
32
|
+
build_dataset: Build YOLO dataset for training or validation.
|
33
|
+
get_dataloader: Construct and return dataloader for the specified mode.
|
34
|
+
preprocess_batch: Preprocess a batch of images by scaling and converting to float.
|
35
|
+
set_model_attributes: Set model attributes based on dataset information.
|
36
|
+
get_model: Return a YOLO detection model.
|
37
|
+
get_validator: Return a validator for model evaluation.
|
38
|
+
label_loss_items: Return a loss dictionary with labeled training loss items.
|
39
|
+
progress_string: Return a formatted string of training progress.
|
40
|
+
plot_training_samples: Plot training samples with their annotations.
|
41
|
+
plot_metrics: Plot metrics from a CSV file.
|
42
|
+
plot_training_labels: Create a labeled training plot of the YOLO model.
|
43
|
+
auto_batch: Calculate optimal batch size based on model memory requirements.
|
44
|
+
|
23
45
|
Examples:
|
24
46
|
>>> from ultralytics.models.yolo.detect import DetectionTrainer
|
25
47
|
>>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
|
@@ -29,18 +51,32 @@ class DetectionTrainer(BaseTrainer):
|
|
29
51
|
|
30
52
|
def build_dataset(self, img_path, mode="train", batch=None):
|
31
53
|
"""
|
32
|
-
Build YOLO Dataset.
|
54
|
+
Build YOLO Dataset for training or validation.
|
33
55
|
|
34
56
|
Args:
|
35
57
|
img_path (str): Path to the folder containing images.
|
36
58
|
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
37
|
-
batch (int, optional): Size of batches, this is for `rect`.
|
59
|
+
batch (int, optional): Size of batches, this is for `rect`.
|
60
|
+
|
61
|
+
Returns:
|
62
|
+
(Dataset): YOLO dataset object configured for the specified mode.
|
38
63
|
"""
|
39
64
|
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
40
65
|
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
|
41
66
|
|
42
67
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
43
|
-
"""
|
68
|
+
"""
|
69
|
+
Construct and return dataloader for the specified mode.
|
70
|
+
|
71
|
+
Args:
|
72
|
+
dataset_path (str): Path to the dataset.
|
73
|
+
batch_size (int): Number of images per batch.
|
74
|
+
rank (int): Process rank for distributed training.
|
75
|
+
mode (str): 'train' for training dataloader, 'val' for validation dataloader.
|
76
|
+
|
77
|
+
Returns:
|
78
|
+
(DataLoader): PyTorch dataloader object.
|
79
|
+
"""
|
44
80
|
assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
|
45
81
|
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
46
82
|
dataset = self.build_dataset(dataset_path, mode, batch_size)
|
@@ -52,7 +88,15 @@ class DetectionTrainer(BaseTrainer):
|
|
52
88
|
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
|
53
89
|
|
54
90
|
def preprocess_batch(self, batch):
|
55
|
-
"""
|
91
|
+
"""
|
92
|
+
Preprocess a batch of images by scaling and converting to float.
|
93
|
+
|
94
|
+
Args:
|
95
|
+
batch (Dict): Dictionary containing batch data with 'img' tensor.
|
96
|
+
|
97
|
+
Returns:
|
98
|
+
(Dict): Preprocessed batch with normalized images.
|
99
|
+
"""
|
56
100
|
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
57
101
|
if self.args.multi_scale:
|
58
102
|
imgs = batch["img"]
|
@@ -71,7 +115,8 @@ class DetectionTrainer(BaseTrainer):
|
|
71
115
|
return batch
|
72
116
|
|
73
117
|
def set_model_attributes(self):
|
74
|
-
"""
|
118
|
+
"""Set model attributes based on dataset information."""
|
119
|
+
# Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
|
75
120
|
# self.args.box *= 3 / nl # scale to layers
|
76
121
|
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
77
122
|
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
@@ -81,14 +126,24 @@ class DetectionTrainer(BaseTrainer):
|
|
81
126
|
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
82
127
|
|
83
128
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
84
|
-
"""
|
129
|
+
"""
|
130
|
+
Return a YOLO detection model.
|
131
|
+
|
132
|
+
Args:
|
133
|
+
cfg (str, optional): Path to model configuration file.
|
134
|
+
weights (str, optional): Path to model weights.
|
135
|
+
verbose (bool): Whether to display model information.
|
136
|
+
|
137
|
+
Returns:
|
138
|
+
(DetectionModel): YOLO detection model.
|
139
|
+
"""
|
85
140
|
model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
|
86
141
|
if weights:
|
87
142
|
model.load(weights)
|
88
143
|
return model
|
89
144
|
|
90
145
|
def get_validator(self):
|
91
|
-
"""
|
146
|
+
"""Return a DetectionValidator for YOLO model validation."""
|
92
147
|
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
|
93
148
|
return yolo.detect.DetectionValidator(
|
94
149
|
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
@@ -96,9 +151,14 @@ class DetectionTrainer(BaseTrainer):
|
|
96
151
|
|
97
152
|
def label_loss_items(self, loss_items=None, prefix="train"):
|
98
153
|
"""
|
99
|
-
|
154
|
+
Return a loss dict with labeled training loss items tensor.
|
100
155
|
|
101
|
-
|
156
|
+
Args:
|
157
|
+
loss_items (List[float], optional): List of loss values.
|
158
|
+
prefix (str): Prefix for keys in the returned dictionary.
|
159
|
+
|
160
|
+
Returns:
|
161
|
+
(Dict | List): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
|
102
162
|
"""
|
103
163
|
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
104
164
|
if loss_items is not None:
|
@@ -108,7 +168,7 @@ class DetectionTrainer(BaseTrainer):
|
|
108
168
|
return keys
|
109
169
|
|
110
170
|
def progress_string(self):
|
111
|
-
"""
|
171
|
+
"""Return a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
|
112
172
|
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
|
113
173
|
"Epoch",
|
114
174
|
"GPU_mem",
|
@@ -118,7 +178,13 @@ class DetectionTrainer(BaseTrainer):
|
|
118
178
|
)
|
119
179
|
|
120
180
|
def plot_training_samples(self, batch, ni):
|
121
|
-
"""
|
181
|
+
"""
|
182
|
+
Plot training samples with their annotations.
|
183
|
+
|
184
|
+
Args:
|
185
|
+
batch (Dict): Dictionary containing batch data.
|
186
|
+
ni (int): Number of iterations.
|
187
|
+
"""
|
122
188
|
plot_images(
|
123
189
|
images=batch["img"],
|
124
190
|
batch_idx=batch["batch_idx"],
|
@@ -130,7 +196,7 @@ class DetectionTrainer(BaseTrainer):
|
|
130
196
|
)
|
131
197
|
|
132
198
|
def plot_metrics(self):
|
133
|
-
"""
|
199
|
+
"""Plot metrics from a CSV file."""
|
134
200
|
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
|
135
201
|
|
136
202
|
def plot_training_labels(self):
|
@@ -140,8 +206,12 @@ class DetectionTrainer(BaseTrainer):
|
|
140
206
|
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
|
141
207
|
|
142
208
|
def auto_batch(self):
|
143
|
-
"""
|
209
|
+
"""
|
210
|
+
Get optimal batch size by calculating memory occupation of model.
|
211
|
+
|
212
|
+
Returns:
|
213
|
+
(int): Optimal batch size.
|
214
|
+
"""
|
144
215
|
train_dataset = self.build_dataset(self.trainset, mode="train", batch=16)
|
145
|
-
# 4 for mosaic augmentation
|
146
|
-
max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4
|
216
|
+
max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4 # 4 for mosaic augmentation
|
147
217
|
return super().auto_batch(max_num_obj)
|
@@ -18,6 +18,22 @@ class DetectionValidator(BaseValidator):
|
|
18
18
|
"""
|
19
19
|
A class extending the BaseValidator class for validation based on a detection model.
|
20
20
|
|
21
|
+
This class implements validation functionality specific to object detection tasks, including metrics calculation,
|
22
|
+
prediction processing, and visualization of results.
|
23
|
+
|
24
|
+
Attributes:
|
25
|
+
nt_per_class (np.ndarray): Number of targets per class.
|
26
|
+
nt_per_image (np.ndarray): Number of targets per image.
|
27
|
+
is_coco (bool): Whether the dataset is COCO.
|
28
|
+
is_lvis (bool): Whether the dataset is LVIS.
|
29
|
+
class_map (List): Mapping from model class indices to dataset class indices.
|
30
|
+
metrics (DetMetrics): Object detection metrics calculator.
|
31
|
+
iouv (torch.Tensor): IoU thresholds for mAP calculation.
|
32
|
+
niou (int): Number of IoU thresholds.
|
33
|
+
lb (List): List for storing ground truth labels for hybrid saving.
|
34
|
+
jdict (List): List for storing JSON detection results.
|
35
|
+
stats (Dict): Dictionary for storing statistics during validation.
|
36
|
+
|
21
37
|
Examples:
|
22
38
|
>>> from ultralytics.models.yolo.detect import DetectionValidator
|
23
39
|
>>> args = dict(model="yolo11n.pt", data="coco8.yaml")
|
@@ -26,7 +42,16 @@ class DetectionValidator(BaseValidator):
|
|
26
42
|
"""
|
27
43
|
|
28
44
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
29
|
-
"""
|
45
|
+
"""
|
46
|
+
Initialize detection validator with necessary variables and settings.
|
47
|
+
|
48
|
+
Args:
|
49
|
+
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
|
50
|
+
save_dir (Path, optional): Directory to save results.
|
51
|
+
pbar (Any, optional): Progress bar for displaying progress.
|
52
|
+
args (Dict, optional): Arguments for the validator.
|
53
|
+
_callbacks (List, optional): List of callback functions.
|
54
|
+
"""
|
30
55
|
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
31
56
|
self.nt_per_class = None
|
32
57
|
self.nt_per_image = None
|
@@ -45,7 +70,15 @@ class DetectionValidator(BaseValidator):
|
|
45
70
|
)
|
46
71
|
|
47
72
|
def preprocess(self, batch):
|
48
|
-
"""
|
73
|
+
"""
|
74
|
+
Preprocess batch of images for YOLO validation.
|
75
|
+
|
76
|
+
Args:
|
77
|
+
batch (Dict): Batch containing images and annotations.
|
78
|
+
|
79
|
+
Returns:
|
80
|
+
(Dict): Preprocessed batch.
|
81
|
+
"""
|
49
82
|
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
50
83
|
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
|
51
84
|
for k in ["batch_idx", "cls", "bboxes"]:
|
@@ -63,7 +96,12 @@ class DetectionValidator(BaseValidator):
|
|
63
96
|
return batch
|
64
97
|
|
65
98
|
def init_metrics(self, model):
|
66
|
-
"""
|
99
|
+
"""
|
100
|
+
Initialize evaluation metrics for YOLO detection validation.
|
101
|
+
|
102
|
+
Args:
|
103
|
+
model (torch.nn.Module): Model to validate.
|
104
|
+
"""
|
67
105
|
val = self.data.get(self.args.split, "") # validation path
|
68
106
|
self.is_coco = (
|
69
107
|
isinstance(val, str)
|
@@ -88,7 +126,15 @@ class DetectionValidator(BaseValidator):
|
|
88
126
|
return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
|
89
127
|
|
90
128
|
def postprocess(self, preds):
|
91
|
-
"""
|
129
|
+
"""
|
130
|
+
Apply Non-maximum suppression to prediction outputs.
|
131
|
+
|
132
|
+
Args:
|
133
|
+
preds (torch.Tensor): Raw predictions from the model.
|
134
|
+
|
135
|
+
Returns:
|
136
|
+
(List[torch.Tensor]): Processed predictions after NMS.
|
137
|
+
"""
|
92
138
|
return ops.non_max_suppression(
|
93
139
|
preds,
|
94
140
|
self.args.conf,
|
@@ -103,7 +149,16 @@ class DetectionValidator(BaseValidator):
|
|
103
149
|
)
|
104
150
|
|
105
151
|
def _prepare_batch(self, si, batch):
|
106
|
-
"""
|
152
|
+
"""
|
153
|
+
Prepare a batch of images and annotations for validation.
|
154
|
+
|
155
|
+
Args:
|
156
|
+
si (int): Batch index.
|
157
|
+
batch (Dict): Batch data containing images and annotations.
|
158
|
+
|
159
|
+
Returns:
|
160
|
+
(Dict): Prepared batch with processed annotations.
|
161
|
+
"""
|
107
162
|
idx = batch["batch_idx"] == si
|
108
163
|
cls = batch["cls"][idx].squeeze(-1)
|
109
164
|
bbox = batch["bboxes"][idx]
|
@@ -116,7 +171,16 @@ class DetectionValidator(BaseValidator):
|
|
116
171
|
return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
|
117
172
|
|
118
173
|
def _prepare_pred(self, pred, pbatch):
|
119
|
-
"""
|
174
|
+
"""
|
175
|
+
Prepare predictions for evaluation against ground truth.
|
176
|
+
|
177
|
+
Args:
|
178
|
+
pred (torch.Tensor): Model predictions.
|
179
|
+
pbatch (Dict): Prepared batch information.
|
180
|
+
|
181
|
+
Returns:
|
182
|
+
(torch.Tensor): Prepared predictions in native space.
|
183
|
+
"""
|
120
184
|
predn = pred.clone()
|
121
185
|
ops.scale_boxes(
|
122
186
|
pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
|
@@ -124,7 +188,13 @@ class DetectionValidator(BaseValidator):
|
|
124
188
|
return predn
|
125
189
|
|
126
190
|
def update_metrics(self, preds, batch):
|
127
|
-
"""
|
191
|
+
"""
|
192
|
+
Update metrics with new predictions and ground truth.
|
193
|
+
|
194
|
+
Args:
|
195
|
+
preds (List[torch.Tensor]): List of predictions from the model.
|
196
|
+
batch (Dict): Batch data containing ground truth.
|
197
|
+
"""
|
128
198
|
for si, pred in enumerate(preds):
|
129
199
|
self.seen += 1
|
130
200
|
npr = len(pred)
|
@@ -173,12 +243,23 @@ class DetectionValidator(BaseValidator):
|
|
173
243
|
)
|
174
244
|
|
175
245
|
def finalize_metrics(self, *args, **kwargs):
|
176
|
-
"""
|
246
|
+
"""
|
247
|
+
Set final values for metrics speed and confusion matrix.
|
248
|
+
|
249
|
+
Args:
|
250
|
+
*args (Any): Variable length argument list.
|
251
|
+
**kwargs (Any): Arbitrary keyword arguments.
|
252
|
+
"""
|
177
253
|
self.metrics.speed = self.speed
|
178
254
|
self.metrics.confusion_matrix = self.confusion_matrix
|
179
255
|
|
180
256
|
def get_stats(self):
|
181
|
-
"""
|
257
|
+
"""
|
258
|
+
Calculate and return metrics statistics.
|
259
|
+
|
260
|
+
Returns:
|
261
|
+
(Dict): Dictionary containing metrics results.
|
262
|
+
"""
|
182
263
|
stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
|
183
264
|
self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
|
184
265
|
self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
|
@@ -188,7 +269,7 @@ class DetectionValidator(BaseValidator):
|
|
188
269
|
return self.metrics.results_dict
|
189
270
|
|
190
271
|
def print_results(self):
|
191
|
-
"""
|
272
|
+
"""Print training/validation set metrics per class."""
|
192
273
|
pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
|
193
274
|
LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
|
194
275
|
if self.nt_per_class.sum() == 0:
|
@@ -220,10 +301,6 @@ class DetectionValidator(BaseValidator):
|
|
220
301
|
|
221
302
|
Returns:
|
222
303
|
(torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
|
223
|
-
|
224
|
-
Note:
|
225
|
-
The function does not return any value directly usable for metrics calculation. Instead, it provides an
|
226
|
-
intermediate representation used for evaluating predictions against ground truth.
|
227
304
|
"""
|
228
305
|
iou = box_iou(gt_bboxes, detections[:, :4])
|
229
306
|
return self.match_predictions(detections[:, 5], gt_cls, iou)
|
@@ -235,17 +312,35 @@ class DetectionValidator(BaseValidator):
|
|
235
312
|
Args:
|
236
313
|
img_path (str): Path to the folder containing images.
|
237
314
|
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
238
|
-
batch (int, optional): Size of batches, this is for `rect`.
|
315
|
+
batch (int, optional): Size of batches, this is for `rect`.
|
316
|
+
|
317
|
+
Returns:
|
318
|
+
(Dataset): YOLO dataset.
|
239
319
|
"""
|
240
320
|
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
|
241
321
|
|
242
322
|
def get_dataloader(self, dataset_path, batch_size):
|
243
|
-
"""
|
323
|
+
"""
|
324
|
+
Construct and return dataloader.
|
325
|
+
|
326
|
+
Args:
|
327
|
+
dataset_path (str): Path to the dataset.
|
328
|
+
batch_size (int): Size of each batch.
|
329
|
+
|
330
|
+
Returns:
|
331
|
+
(torch.utils.data.DataLoader): Dataloader for validation.
|
332
|
+
"""
|
244
333
|
dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
|
245
334
|
return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader
|
246
335
|
|
247
336
|
def plot_val_samples(self, batch, ni):
|
248
|
-
"""
|
337
|
+
"""
|
338
|
+
Plot validation image samples.
|
339
|
+
|
340
|
+
Args:
|
341
|
+
batch (Dict): Batch containing images and annotations.
|
342
|
+
ni (int): Batch index.
|
343
|
+
"""
|
249
344
|
plot_images(
|
250
345
|
batch["img"],
|
251
346
|
batch["batch_idx"],
|
@@ -258,7 +353,14 @@ class DetectionValidator(BaseValidator):
|
|
258
353
|
)
|
259
354
|
|
260
355
|
def plot_predictions(self, batch, preds, ni):
|
261
|
-
"""
|
356
|
+
"""
|
357
|
+
Plot predicted bounding boxes on input images and save the result.
|
358
|
+
|
359
|
+
Args:
|
360
|
+
batch (Dict): Batch containing images and annotations.
|
361
|
+
preds (List[torch.Tensor]): List of predictions from the model.
|
362
|
+
ni (int): Batch index.
|
363
|
+
"""
|
262
364
|
plot_images(
|
263
365
|
batch["img"],
|
264
366
|
*output_to_target(preds, max_det=self.args.max_det),
|
@@ -269,7 +371,15 @@ class DetectionValidator(BaseValidator):
|
|
269
371
|
) # pred
|
270
372
|
|
271
373
|
def save_one_txt(self, predn, save_conf, shape, file):
|
272
|
-
"""
|
374
|
+
"""
|
375
|
+
Save YOLO detections to a txt file in normalized coordinates in a specific format.
|
376
|
+
|
377
|
+
Args:
|
378
|
+
predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
|
379
|
+
save_conf (bool): Whether to save confidence scores.
|
380
|
+
shape (tuple): Shape of the original image.
|
381
|
+
file (Path): File path to save the detections.
|
382
|
+
"""
|
273
383
|
from ultralytics.engine.results import Results
|
274
384
|
|
275
385
|
Results(
|
@@ -280,7 +390,13 @@ class DetectionValidator(BaseValidator):
|
|
280
390
|
).save_txt(file, save_conf=save_conf)
|
281
391
|
|
282
392
|
def pred_to_json(self, predn, filename):
|
283
|
-
"""
|
393
|
+
"""
|
394
|
+
Serialize YOLO predictions to COCO json format.
|
395
|
+
|
396
|
+
Args:
|
397
|
+
predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
|
398
|
+
filename (str): Image filename.
|
399
|
+
"""
|
284
400
|
stem = Path(filename).stem
|
285
401
|
image_id = int(stem) if stem.isnumeric() else stem
|
286
402
|
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
@@ -296,7 +412,15 @@ class DetectionValidator(BaseValidator):
|
|
296
412
|
)
|
297
413
|
|
298
414
|
def eval_json(self, stats):
|
299
|
-
"""
|
415
|
+
"""
|
416
|
+
Evaluate YOLO output in JSON format and return performance statistics.
|
417
|
+
|
418
|
+
Args:
|
419
|
+
stats (Dict): Current statistics dictionary.
|
420
|
+
|
421
|
+
Returns:
|
422
|
+
(Dict): Updated statistics dictionary with COCO/LVIS evaluation results.
|
423
|
+
"""
|
300
424
|
if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
|
301
425
|
pred_json = self.save_dir / "predictions.json" # predictions
|
302
426
|
anno_json = (
|