dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/sam3/vl_combiner.py
ADDED

@@ -0,0 +1,160 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""Provides utility to combine a vision backbone with a language backbone."""
+
+from __future__ import annotations
+
+from copy import copy
+
+import torch
+import torch.nn as nn
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from .necks import Sam3DualViTDetNeck
+
+
+class SAM3VLBackbone(nn.Module):
+    """This backbone combines a vision backbone and a language backbone without fusion. As such it is more of a
+    convenience wrapper to handle the two backbones together.
+
+    It adds support for activation checkpointing and compilation.
+    """
+
+    def __init__(
+        self,
+        visual: Sam3DualViTDetNeck,
+        text,
+        compile_visual: bool = False,
+        act_ckpt_whole_vision_backbone: bool = False,
+        act_ckpt_whole_language_backbone: bool = False,
+        scalp=0,
+    ):
+        """Initialize the backbone combiner.
+
+        :param visual: The vision backbone to use
+        :param text: The text encoder to use
+        """
+        super().__init__()
+        self.vision_backbone: Sam3DualViTDetNeck = torch.compile(visual) if compile_visual else visual
+        self.language_backbone = text
+        self.scalp = scalp
+        # allow running activation checkpointing on the entire vision and language backbones
+        self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
+        self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone
+
+    def forward(
+        self,
+        samples: torch.Tensor,
+        captions: list[str],
+        input_boxes: torch.Tensor = None,
+        additional_text: list[str] | None = None,
+    ):
+        """Forward pass of the backbone combiner.
+
+        :param samples: The input images
+        :param captions: The input captions
+        :param input_boxes: If the text contains place-holders for boxes, this
+            parameter contains the tensor containing their spatial features
+        :param additional_text: This can be used to encode some additional text
+            (different from the captions) in the same forward of the backbone
+        :return: Output dictionary with the following keys:
+            - vision_features: The output of the vision backbone
+            - language_features: The output of the language backbone
+            - language_mask: The attention mask of the language backbone
+            - vision_pos_enc: The positional encoding of the vision backbone
+            - (optional) additional_text_features: The output of the language
+                backbone for the additional text
+            - (optional) additional_text_mask: The attention mask of the
+                language backbone for the additional text
+        """
+        output = self.forward_image(samples)
+        output.update(self.forward_text(captions, input_boxes, additional_text))
+        return output
+
+    def forward_image(self, samples: torch.Tensor):
+        """Forward pass of the vision backbone and get both SAM3 and SAM2 features."""
+        # Forward through backbone
+        sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(samples)
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            sam3_features, sam3_pos = (
+                sam3_features[: -self.scalp],
+                sam3_pos[: -self.scalp],
+            )
+            if sam2_features is not None and sam2_pos is not None:
+                sam2_features, sam2_pos = (
+                    sam2_features[: -self.scalp],
+                    sam2_pos[: -self.scalp],
+                )
+
+        sam2_output = None
+
+        if sam2_features is not None and sam2_pos is not None:
+            sam2_src = sam2_features[-1]
+            sam2_output = {
+                "vision_features": sam2_src,
+                "vision_pos_enc": sam2_pos,
+                "backbone_fpn": sam2_features,
+            }
+
+        sam3_src = sam3_features[-1]
+        return {
+            "vision_features": sam3_src,
+            "vision_pos_enc": sam3_pos,
+            "backbone_fpn": sam3_features,
+            "sam2_backbone_out": sam2_output,
+        }
+
+    def forward_image_sam2(self, samples: torch.Tensor):
+        """Forward pass of the vision backbone to get SAM2 features only."""
+        xs = self.vision_backbone.trunk(samples)
+        x = xs[-1]  # simpleFPN
+
+        assert self.vision_backbone.sam2_convs is not None, "SAM2 neck is not available."
+        sam2_features, sam2_pos = self.vision_backbone.sam_forward_feature_levels(x, self.vision_backbone.sam2_convs)
+
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            sam2_features, sam2_pos = (
+                sam2_features[: -self.scalp],
+                sam2_pos[: -self.scalp],
+            )
+
+        return {
+            "vision_features": sam2_features[-1],
+            "vision_pos_enc": sam2_pos,
+            "backbone_fpn": sam2_features,
+        }
+
+    def forward_text(self, captions, input_boxes=None, additional_text=None):
+        """Forward pass of the text encoder."""
+        output = {}
+
+        # Forward through text_encoder
+        text_to_encode = copy(captions)
+        if additional_text is not None:
+            # if there are additional_text, we piggy-back them into this forward.
+            # They'll be used later for output alignment
+            text_to_encode += additional_text
+
+        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]):
+            text_attention_mask, text_memory, text_embeds = self.language_backbone(text_to_encode, input_boxes)
+
+        if additional_text is not None:
+            output["additional_text_features"] = text_memory[:, -len(additional_text) :]
+            output["additional_text_mask"] = text_attention_mask[-len(additional_text) :]
+
+        text_memory = text_memory[:, : len(captions)]
+        text_attention_mask = text_attention_mask[: len(captions)]
+        text_embeds = text_embeds[:, : len(captions)]
+        output["language_features"] = text_memory
+        output["language_mask"] = text_attention_mask
+        output["language_embeds"] = text_embeds  # Text embeddings before forward to the encoder
+
+        return output
+
+    def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
+        """Set the image size for the vision backbone."""
+        self.vision_backbone.set_imgsz(imgsz)
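The new SAM3VLBackbone has a simple contract: the vision neck returns a (sam3_features, sam3_pos, sam2_features, sam2_pos) 4-tuple, the text encoder returns an (attention_mask, memory, embeds) triple with the batch on dim 1 of memory/embeds (as the slicing in forward_text implies), and forward merges both into one dictionary. A minimal sketch with stub sub-modules exercises that contract without SAM3 weights; the stub shapes below are illustrative assumptions, not values from the diff:

import torch
import torch.nn as nn

from ultralytics.models.sam.sam3.vl_combiner import SAM3VLBackbone


class StubNeck(nn.Module):
    # Mimics Sam3DualViTDetNeck.forward: (sam3_features, sam3_pos, sam2_features, sam2_pos)
    def forward(self, x):
        feats = [torch.zeros(x.shape[0], 256, s, s) for s in (72, 36, 18)]
        return feats, [torch.zeros_like(f) for f in feats], None, None


class StubText(nn.Module):
    # Mimics the text encoder: (attention_mask, memory, embeds), batch on dim 1 of memory/embeds
    def forward(self, texts, input_boxes=None):
        n, seq, dim = len(texts), 8, 256
        return torch.ones(n, seq, dtype=torch.bool), torch.zeros(seq, n, dim), torch.zeros(seq, n, dim)


backbone = SAM3VLBackbone(visual=StubNeck(), text=StubText())
out = backbone(torch.zeros(1, 3, 1008, 1008), captions=["a red ball"])
print(sorted(out))  # backbone_fpn, language_embeds/features/mask, sam2_backbone_out, vision_features, vision_pos_enc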
ultralytics/models/utils/loss.py
CHANGED

@@ -15,11 +15,10 @@ from .ops import HungarianMatcher
 
 
 class DETRLoss(nn.Module):
-    """
-    DETR (DEtection TRansformer) Loss class for calculating various loss components.
+    """DETR (DEtection TRansformer) Loss class for calculating various loss components.
 
-    This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the
-    DETR object detection model.
+    This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the DETR
+    object detection model.
 
     Attributes:
         nc (int): Number of classes.
@@ -47,8 +46,7 @@ class DETRLoss(nn.Module):
         gamma: float = 1.5,
         alpha: float = 0.25,
     ):
-        """
-        Initialize DETR loss function with customizable components and gains.
+        """Initialize DETR loss function with customizable components and gains.
 
         Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
        losses and various loss types.
@@ -82,8 +80,7 @@ class DETRLoss(nn.Module):
     def _get_loss_class(
         self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
     ) -> dict[str, torch.Tensor]:
-        """
-        Compute classification loss based on predictions, target values, and ground truth scores.
+        """Compute classification loss based on predictions, target values, and ground truth scores.
 
         Args:
             pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
@@ -124,8 +121,7 @@ class DETRLoss(nn.Module):
     def _get_loss_bbox(
         self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
     ) -> dict[str, torch.Tensor]:
-        """
-        Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
+        """Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
@@ -199,8 +195,7 @@ class DETRLoss(nn.Module):
         masks: torch.Tensor | None = None,
         gt_mask: torch.Tensor | None = None,
     ) -> dict[str, torch.Tensor]:
-        """
-        Get auxiliary losses for intermediate decoder layers.
+        """Get auxiliary losses for intermediate decoder layers.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
@@ -261,8 +256,7 @@ class DETRLoss(nn.Module):
 
     @staticmethod
     def _get_index(match_indices: list[tuple]) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
-        """
-        Extract batch indices, source indices, and destination indices from match indices.
+        """Extract batch indices, source indices, and destination indices from match indices.
 
         Args:
             match_indices (list[tuple]): List of tuples containing matched indices.
@@ -279,8 +273,7 @@ class DETRLoss(nn.Module):
     def _get_assigned_bboxes(
         self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: list[tuple]
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
+        """Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes.
@@ -317,8 +310,7 @@ class DETRLoss(nn.Module):
         postfix: str = "",
         match_indices: list[tuple] | None = None,
     ) -> dict[str, torch.Tensor]:
-        """
-        Calculate losses for a single prediction layer.
+        """Calculate losses for a single prediction layer.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes.
@@ -364,8 +356,7 @@ class DETRLoss(nn.Module):
         postfix: str = "",
         **kwargs: Any,
     ) -> dict[str, torch.Tensor]:
-        """
-        Calculate loss for predicted bounding boxes and scores.
+        """Calculate loss for predicted bounding boxes and scores.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
@@ -400,8 +391,7 @@ class DETRLoss(nn.Module):
 
 
 class RTDETRDetectionLoss(DETRLoss):
-    """
-    Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
+    """Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
 
     This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
     an additional denoising training loss when provided with denoising metadata.
@@ -415,8 +405,7 @@ class RTDETRDetectionLoss(DETRLoss):
         dn_scores: torch.Tensor | None = None,
         dn_meta: dict[str, Any] | None = None,
     ) -> dict[str, torch.Tensor]:
-        """
-        Forward pass to compute detection loss with optional denoising loss.
+        """Forward pass to compute detection loss with optional denoising loss.
 
         Args:
             preds (tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
@@ -452,8 +441,7 @@ class RTDETRDetectionLoss(DETRLoss):
     def get_dn_match_indices(
         dn_pos_idx: list[torch.Tensor], dn_num_group: int, gt_groups: list[int]
     ) -> list[tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Get match indices for denoising.
+        """Get match indices for denoising.
 
         Args:
             dn_pos_idx (list[torch.Tensor]): List of tensors containing positive indices for denoising.
ultralytics/models/utils/ops.py
CHANGED

@@ -14,8 +14,7 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
 
 
 class HungarianMatcher(nn.Module):
-    """
-    A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
+    """A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
 
     HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
     function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
@@ -56,8 +55,7 @@ class HungarianMatcher(nn.Module):
         alpha: float = 0.25,
         gamma: float = 2.0,
     ):
-        """
-        Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
+        """Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
 
         Args:
             cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
@@ -88,8 +86,7 @@ class HungarianMatcher(nn.Module):
         masks: torch.Tensor | None = None,
         gt_mask: list[torch.Tensor] | None = None,
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
+        """Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
 
         This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
         mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.
@@ -105,10 +102,10 @@ class HungarianMatcher(nn.Module):
             gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).
 
         Returns:
-            (list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
-                (index_i, index_j), where index_i is the tensor of indices of the selected predictions (in order) and
-                index_j is the tensor of indices of the corresponding selected ground truth targets (in order). For
-                each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
+            (list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i,
+                index_j), where index_i is the tensor of indices of the selected predictions (in order) and index_j is
+                the tensor of indices of the corresponding selected ground truth targets (in order).
+                For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
         """
         bs, nq, nc = pred_scores.shape
 
@@ -198,16 +195,15 @@ def get_cdn_group(
     box_noise_scale: float = 1.0,
     training: bool = False,
 ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
-    """
-    Generate contrastive denoising training group with positive and negative samples from ground truths.
+    """Generate contrastive denoising training group with positive and negative samples from ground truths.
 
-    This function creates denoising queries for contrastive denoising training by adding noise to ground truth
-    bounding boxes and class labels. It generates both positive and negative samples to improve model robustness.
+    This function creates denoising queries for contrastive denoising training by adding noise to ground truth bounding
+    boxes and class labels. It generates both positive and negative samples to improve model robustness.
 
     Args:
-        batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),
-            'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of
-            ground truths per image.
+        batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)), 'gt_bboxes'
+            (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of ground truths
+            per image.
         num_classes (int): Total number of object classes.
         num_queries (int): Number of object queries.
         class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
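The rewritten Returns block above describes the matcher's output contract: per image, a pair (index_i, index_j) of equal length min(num_queries, num_target_boxes). SciPy's linear_sum_assignment, the standard Hungarian solver, produces exactly such index pairs from a cost matrix; the cost values below are made up for illustration:

import numpy as np
from scipy.optimize import linear_sum_assignment

# One batch element: 4 queries (rows) vs 2 ground-truth boxes (columns).
cost = np.array([[0.9, 0.1],
                 [0.4, 0.8],
                 [0.2, 0.7],
                 [0.6, 0.3]])
index_i, index_j = linear_sum_assignment(cost)  # optimal bipartite assignment, minimizing total cost
print(index_i, index_j)  # [0 2] [1 0]; len(index_i) == len(index_j) == min(4, 2)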
ultralytics/models/yolo/__init__.py
CHANGED

@@ -4,4 +4,4 @@ from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe
 
 from .model import YOLO, YOLOE, YOLOWorld
 
-__all__ = "
+__all__ = "YOLO", "YOLOE", "YOLOWorld", "classify", "detect", "obb", "pose", "segment", "world", "yoloe"
ultralytics/models/yolo/classify/predict.py
CHANGED

@@ -11,11 +11,10 @@ from ultralytics.utils import DEFAULT_CFG, ops
 
 
 class ClassificationPredictor(BasePredictor):
-    """
-    A class extending the BasePredictor class for prediction based on a classification model.
+    """A class extending the BasePredictor class for prediction based on a classification model.
 
-    This predictor handles the specific requirements of classification models, including preprocessing images
-    and postprocessing predictions to generate classification results.
+    This predictor handles the specific requirements of classification models, including preprocessing images and
+    postprocessing predictions to generate classification results.
 
     Attributes:
         args (dict): Configuration arguments for the predictor.
@@ -24,20 +23,19 @@ class ClassificationPredictor(BasePredictor):
         preprocess: Convert input images to model-compatible format.
         postprocess: Process model predictions into Results objects.
 
-    Notes:
-        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
-
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.yolo.classify import ClassificationPredictor
         >>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
         >>> predictor = ClassificationPredictor(overrides=args)
         >>> predictor.predict_cli()
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
+        """Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
 
         This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
         tasks. It ensures the task is set to 'classify' regardless of input configuration.
@@ -72,8 +70,7 @@ class ClassificationPredictor(BasePredictor):
         return img.half() if self.model.fp16 else img.float()  # Convert uint8 to fp16/32
 
     def postprocess(self, preds, img, orig_imgs):
-        """
-        Process predictions to return Results objects with classification probabilities.
+        """Process predictions to return Results objects with classification probabilities.
 
         Args:
             preds (torch.Tensor): Raw predictions from the model.
@@ -84,7 +81,7 @@ class ClassificationPredictor(BasePredictor):
             (list[Results]): List of Results objects containing classification results for each image.
         """
         if not isinstance(orig_imgs, list):  # Input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]
 
         preds = preds[0] if isinstance(preds, (list, tuple)) else preds
         return [
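The only functional change in this file is the `[..., ::-1]` appended to `convert_torch2numpy_batch`, which reverses the last (channel) axis of the BHWC numpy batch, i.e. swaps RGB/BGR ordering of the original images attached to the Results. A numpy-only illustration (the shapes are made up):

import numpy as np

batch = np.zeros((1, 4, 4, 3), dtype=np.uint8)  # BHWC, like convert_torch2numpy_batch output
batch[..., 0] = 255  # light up channel 0 everywhere
flipped = batch[..., ::-1]  # reverse the last (channel) axis
print(batch[0, 0, 0], flipped[0, 0, 0])  # [255 0 0] -> [0 0 255]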
ultralytics/models/yolo/classify/train.py
CHANGED

@@ -11,14 +11,13 @@ from ultralytics.data import ClassificationDataset, build_dataloader
 from ultralytics.engine.trainer import BaseTrainer
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import ClassificationModel
-from ultralytics.utils import DEFAULT_CFG,
+from ultralytics.utils import DEFAULT_CFG, RANK
 from ultralytics.utils.plotting import plot_images
-from ultralytics.utils.torch_utils import is_parallel,
+from ultralytics.utils.torch_utils import is_parallel, torch_distributed_zero_first
 
 
 class ClassificationTrainer(BaseTrainer):
-    """
-    A trainer class extending BaseTrainer for training image classification models.
+    """A trainer class extending BaseTrainer for training image classification models.
 
     This trainer handles the training process for image classification tasks, supporting both YOLO classification models
     and torchvision models with comprehensive dataset handling and validation.
@@ -38,7 +37,7 @@ class ClassificationTrainer(BaseTrainer):
         preprocess_batch: Preprocess a batch of images and classes.
         progress_string: Return a formatted string showing training progress.
         get_validator: Return an instance of ClassificationValidator.
-        label_loss_items: Return a loss dict with labelled training loss items.
+        label_loss_items: Return a loss dict with labeled training loss items.
         final_eval: Evaluate trained model and save validation results.
         plot_training_samples: Plot training samples with their annotations.
 
@@ -51,8 +50,7 @@ class ClassificationTrainer(BaseTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a ClassificationTrainer object.
+        """Initialize a ClassificationTrainer object.
 
         Args:
             cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
@@ -71,8 +69,7 @@ class ClassificationTrainer(BaseTrainer):
         self.model.names = self.data["names"]
 
     def get_model(self, cfg=None, weights=None, verbose: bool = True):
-        """
-        Return a modified PyTorch model configured for training YOLO classification.
+        """Return a modified PyTorch model configured for training YOLO classification.
 
         Args:
             cfg (Any, optional): Model configuration.
@@ -96,8 +93,7 @@ class ClassificationTrainer(BaseTrainer):
         return model
 
     def setup_model(self):
-        """
-        Load, create or download model for classification tasks.
+        """Load, create or download model for classification tasks.
 
         Returns:
             (Any): Model checkpoint if applicable, otherwise None.
@@ -115,8 +111,7 @@ class ClassificationTrainer(BaseTrainer):
         return ckpt
 
     def build_dataset(self, img_path: str, mode: str = "train", batch=None):
-        """
-        Create a ClassificationDataset instance given an image path and mode.
+        """Create a ClassificationDataset instance given an image path and mode.
 
         Args:
             img_path (str): Path to the dataset images.
@@ -129,8 +124,7 @@ class ClassificationTrainer(BaseTrainer):
         return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
 
     def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
-        """
-        Return PyTorch DataLoader with transforms to preprocess images.
+        """Return PyTorch DataLoader with transforms to preprocess images.
 
         Args:
             dataset_path (str): Path to the dataset.
@@ -177,8 +171,7 @@ class ClassificationTrainer(BaseTrainer):
         )
 
     def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
-        """
-        Return a loss dict with labelled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.
 
         Args:
             loss_items (torch.Tensor, optional): Loss tensor items.
@@ -194,22 +187,8 @@ class ClassificationTrainer(BaseTrainer):
             loss_items = [round(float(loss_items), 5)]
         return dict(zip(keys, loss_items))
 
-    def final_eval(self):
-        """Evaluate trained model and save validation results."""
-        for f in self.last, self.best:
-            if f.exists():
-                strip_optimizer(f)  # strip optimizers
-                if f is self.best:
-                    LOGGER.info(f"\nValidating {f}...")
-                    self.validator.args.data = self.args.data
-                    self.validator.args.plots = self.args.plots
-                    self.metrics = self.validator(model=f)
-                    self.metrics.pop("fitness", None)
-                    self.run_callbacks("on_fit_epoch_end")
-
     def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
-        """
-        Plot training samples with their annotations.
+        """Plot training samples with their annotations.
 
         Args:
             batch (dict[str, torch.Tensor]): Batch containing images and class labels.
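The import change above swaps in RANK and torch_distributed_zero_first, the Ultralytics context manager that lets rank 0 finish setup work (e.g. downloading or caching a dataset) before the other DDP ranks proceed past a barrier. A sketch of the usual pattern; the wrapper function below is hypothetical, not the trainer's exact code:

from ultralytics.data import ClassificationDataset
from ultralytics.utils.torch_utils import torch_distributed_zero_first


def build_dataset_ddp(img_path, args, rank):
    # Non-zero ranks wait at a barrier while rank 0 builds (and caches) the dataset first.
    with torch_distributed_zero_first(rank):
        return ClassificationDataset(root=img_path, args=args, augment=True, prefix="train")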
ultralytics/models/yolo/classify/val.py
CHANGED

@@ -6,20 +6,20 @@ from pathlib import Path
 from typing import Any
 
 import torch
+import torch.distributed as dist
 
 from ultralytics.data import ClassificationDataset, build_dataloader
 from ultralytics.engine.validator import BaseValidator
-from ultralytics.utils import LOGGER
+from ultralytics.utils import LOGGER, RANK
 from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
 from ultralytics.utils.plotting import plot_images
 
 
 class ClassificationValidator(BaseValidator):
-    """
-    A class extending the BaseValidator class for validation based on a classification model.
+    """A class extending the BaseValidator class for validation based on a classification model.
 
-    This validator handles the validation process for classification models, including metrics calculation,
-    confusion matrix generation, and visualization of results.
+    This validator handles the validation process for classification models, including metrics calculation, confusion
+    matrix generation, and visualization of results.
 
     Attributes:
         targets (list[torch.Tensor]): Ground truth class labels.
@@ -54,20 +54,13 @@ class ClassificationValidator(BaseValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize ClassificationValidator with dataloader, save directory, and other parameters.
+        """Initialize ClassificationValidator with dataloader, save directory, and other parameters.
 
         Args:
-            dataloader (torch.utils.data.DataLoader, optional):
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
             save_dir (str | Path, optional): Directory to save results.
             args (dict, optional): Arguments containing model and validation configuration.
             _callbacks (list, optional): List of callback functions to be called during validation.
-
-        Examples:
-            >>> from ultralytics.models.yolo.classify import ClassificationValidator
-            >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
-            >>> validator = ClassificationValidator(args=args)
-            >>> validator()
         """
         super().__init__(dataloader, save_dir, args, _callbacks)
         self.targets = None
@@ -95,8 +88,7 @@ class ClassificationValidator(BaseValidator):
         return batch
 
     def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
-        """
-        Update running metrics with model predictions and batch targets.
+        """Update running metrics with model predictions and batch targets.
 
         Args:
             preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
@@ -111,12 +103,7 @@ class ClassificationValidator(BaseValidator):
         self.targets.append(batch["cls"].type(torch.int32).cpu())
 
     def finalize_metrics(self) -> None:
-        """
-        Finalize metrics including confusion matrix and processing speed.
-
-        Notes:
-            This method processes the accumulated predictions and targets to generate the confusion matrix,
-            optionally plots it, and updates the metrics object with speed information.
+        """Finalize metrics including confusion matrix and processing speed.
 
         Examples:
             >>> validator = ClassificationValidator()
@@ -124,6 +111,10 @@ class ClassificationValidator(BaseValidator):
             >>> validator.targets = [torch.tensor([0])]  # Ground truth class
             >>> validator.finalize_metrics()
             >>> print(validator.metrics.confusion_matrix)  # Access the confusion matrix
+
+        Notes:
+            This method processes the accumulated predictions and targets to generate the confusion matrix,
+            optionally plots it, and updates the metrics object with speed information.
         """
         self.confusion_matrix.process_cls_preds(self.pred, self.targets)
         if self.args.plots:
@@ -142,13 +133,25 @@ class ClassificationValidator(BaseValidator):
         self.metrics.process(self.targets, self.pred)
         return self.metrics.results_dict
 
+    def gather_stats(self) -> None:
+        """Gather stats from all GPUs."""
+        if RANK == 0:
+            gathered_preds = [None] * dist.get_world_size()
+            gathered_targets = [None] * dist.get_world_size()
+            dist.gather_object(self.pred, gathered_preds, dst=0)
+            dist.gather_object(self.targets, gathered_targets, dst=0)
+            self.pred = [pred for rank in gathered_preds for pred in rank]
+            self.targets = [targets for rank in gathered_targets for targets in rank]
+        elif RANK > 0:
+            dist.gather_object(self.pred, None, dst=0)
+            dist.gather_object(self.targets, None, dst=0)
+
     def build_dataset(self, img_path: str) -> ClassificationDataset:
         """Create a ClassificationDataset instance for validation."""
         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
 
     def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
-        """
-        Build and return a data loader for classification validation.
+        """Build and return a data loader for classification validation.
 
         Args:
             dataset_path (str | Path): Path to the dataset directory.
@@ -166,8 +169,7 @@ class ClassificationValidator(BaseValidator):
         LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
 
     def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
-        """
-        Plot validation image samples with their ground truth labels.
+        """Plot validation image samples with their ground truth labels.
 
         Args:
             batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
@@ -187,8 +189,7 @@ class ClassificationValidator(BaseValidator):
         )
 
     def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
-        """
-        Plot images with their predicted class labels and save the visualization.
+        """Plot images with their predicted class labels and save the visualization.
 
         Args:
             batch (dict[str, Any]): Batch data containing images and other information.