dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/utils/ops.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
3
7
|
import torch
|
|
4
8
|
import torch.nn as nn
|
|
5
9
|
import torch.nn.functional as F
|
|
@@ -10,41 +14,57 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
|
|
|
10
14
|
|
|
11
15
|
|
|
12
16
|
class HungarianMatcher(nn.Module):
|
|
13
|
-
"""
|
|
14
|
-
A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an
|
|
15
|
-
end-to-end fashion.
|
|
17
|
+
"""A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
|
|
16
18
|
|
|
17
|
-
HungarianMatcher performs optimal assignment over
|
|
18
|
-
function that considers classification scores, bounding box coordinates, and optionally
|
|
19
|
+
HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
|
|
20
|
+
function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
|
|
21
|
+
used in end-to-end object detection models like DETR.
|
|
19
22
|
|
|
20
23
|
Attributes:
|
|
21
|
-
cost_gain (dict): Dictionary of cost coefficients
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
24
|
+
cost_gain (dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'
|
|
25
|
+
components.
|
|
26
|
+
use_fl (bool): Whether to use Focal Loss for classification cost calculation.
|
|
27
|
+
with_mask (bool): Whether the model makes mask predictions.
|
|
28
|
+
num_sample_points (int): Number of sample points used in mask cost calculation.
|
|
29
|
+
alpha (float): Alpha factor in Focal Loss calculation.
|
|
30
|
+
gamma (float): Gamma factor in Focal Loss calculation.
|
|
27
31
|
|
|
28
32
|
Methods:
|
|
29
|
-
forward:
|
|
30
|
-
_cost_mask:
|
|
33
|
+
forward: Compute optimal assignment between predictions and ground truths for a batch.
|
|
34
|
+
_cost_mask: Compute mask cost and dice cost if masks are predicted.
|
|
35
|
+
|
|
36
|
+
Examples:
|
|
37
|
+
Initialize a HungarianMatcher with custom cost gains
|
|
38
|
+
>>> matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
|
|
39
|
+
|
|
40
|
+
Perform matching between predictions and ground truth
|
|
41
|
+
>>> pred_boxes = torch.rand(2, 100, 4) # batch_size=2, num_queries=100
|
|
42
|
+
>>> pred_scores = torch.rand(2, 100, 80) # 80 classes
|
|
43
|
+
>>> gt_boxes = torch.rand(10, 4) # 10 ground truth boxes
|
|
44
|
+
>>> gt_classes = torch.randint(0, 80, (10,))
|
|
45
|
+
>>> gt_groups = [5, 5] # 5 GT boxes per image
|
|
46
|
+
>>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)
|
|
31
47
|
"""
|
|
32
48
|
|
|
33
|
-
def __init__(
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
cost_gain: dict[str, float] | None = None,
|
|
52
|
+
use_fl: bool = True,
|
|
53
|
+
with_mask: bool = False,
|
|
54
|
+
num_sample_points: int = 12544,
|
|
55
|
+
alpha: float = 0.25,
|
|
56
|
+
gamma: float = 2.0,
|
|
57
|
+
):
|
|
58
|
+
"""Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
|
|
39
59
|
|
|
40
60
|
Args:
|
|
41
|
-
cost_gain (dict, optional): Dictionary of cost coefficients for different
|
|
42
|
-
Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
|
|
43
|
-
use_fl (bool
|
|
44
|
-
with_mask (bool
|
|
45
|
-
num_sample_points (int
|
|
46
|
-
alpha (float
|
|
47
|
-
gamma (float
|
|
61
|
+
cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
|
|
62
|
+
components. Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
|
|
63
|
+
use_fl (bool): Whether to use Focal Loss for classification cost calculation.
|
|
64
|
+
with_mask (bool): Whether the model makes mask predictions.
|
|
65
|
+
num_sample_points (int): Number of sample points used in mask cost calculation.
|
|
66
|
+
alpha (float): Alpha factor in Focal Loss calculation.
|
|
67
|
+
gamma (float): Gamma factor in Focal Loss calculation.
|
|
48
68
|
"""
|
|
49
69
|
super().__init__()
|
|
50
70
|
if cost_gain is None:
|
|
@@ -56,41 +76,48 @@ class HungarianMatcher(nn.Module):
|
|
|
56
76
|
self.alpha = alpha
|
|
57
77
|
self.gamma = gamma
|
|
58
78
|
|
|
59
|
-
def forward(
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
79
|
+
def forward(
|
|
80
|
+
self,
|
|
81
|
+
pred_bboxes: torch.Tensor,
|
|
82
|
+
pred_scores: torch.Tensor,
|
|
83
|
+
gt_bboxes: torch.Tensor,
|
|
84
|
+
gt_cls: torch.Tensor,
|
|
85
|
+
gt_groups: list[int],
|
|
86
|
+
masks: torch.Tensor | None = None,
|
|
87
|
+
gt_mask: list[torch.Tensor] | None = None,
|
|
88
|
+
) -> list[tuple[torch.Tensor, torch.Tensor]]:
|
|
89
|
+
"""Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
|
|
90
|
+
|
|
91
|
+
This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
|
|
92
|
+
mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.
|
|
63
93
|
|
|
64
94
|
Args:
|
|
65
95
|
pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
|
|
66
|
-
pred_scores (torch.Tensor): Predicted scores with shape (batch_size, num_queries,
|
|
67
|
-
|
|
96
|
+
pred_scores (torch.Tensor): Predicted classification scores with shape (batch_size, num_queries,
|
|
97
|
+
num_classes).
|
|
68
98
|
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
|
|
69
|
-
|
|
70
|
-
|
|
99
|
+
gt_cls (torch.Tensor): Ground truth class labels with shape (num_gts,).
|
|
100
|
+
gt_groups (list[int]): Number of ground truth boxes for each image in the batch.
|
|
71
101
|
masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
|
|
72
|
-
gt_mask (
|
|
102
|
+
gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).
|
|
73
103
|
|
|
74
104
|
Returns:
|
|
75
|
-
(
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
|
105
|
+
(list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i,
|
|
106
|
+
index_j), where index_i is the tensor of indices of the selected predictions (in order) and index_j is
|
|
107
|
+
the tensor of indices of the corresponding selected ground truth targets (in order).
|
|
108
|
+
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
|
|
80
109
|
"""
|
|
81
110
|
bs, nq, nc = pred_scores.shape
|
|
82
111
|
|
|
83
112
|
if sum(gt_groups) == 0:
|
|
84
113
|
return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
|
|
85
114
|
|
|
86
|
-
#
|
|
87
|
-
# (batch_size * num_queries, num_classes)
|
|
115
|
+
# Flatten to compute cost matrices in batch format
|
|
88
116
|
pred_scores = pred_scores.detach().view(-1, nc)
|
|
89
117
|
pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
|
|
90
|
-
# (batch_size * num_queries, 4)
|
|
91
118
|
pred_bboxes = pred_bboxes.detach().view(-1, 4)
|
|
92
119
|
|
|
93
|
-
# Compute
|
|
120
|
+
# Compute classification cost
|
|
94
121
|
pred_scores = pred_scores[:, gt_cls]
|
|
95
122
|
if self.use_fl:
|
|
96
123
|
neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
|
|
@@ -99,23 +126,24 @@ class HungarianMatcher(nn.Module):
|
|
|
99
126
|
else:
|
|
100
127
|
cost_class = -pred_scores
|
|
101
128
|
|
|
102
|
-
# Compute
|
|
129
|
+
# Compute L1 cost between boxes
|
|
103
130
|
cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
|
|
104
131
|
|
|
105
|
-
# Compute
|
|
132
|
+
# Compute GIoU cost between boxes, (bs*num_queries, num_gt)
|
|
106
133
|
cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
|
|
107
134
|
|
|
108
|
-
#
|
|
135
|
+
# Combine costs into final cost matrix
|
|
109
136
|
C = (
|
|
110
137
|
self.cost_gain["class"] * cost_class
|
|
111
138
|
+ self.cost_gain["bbox"] * cost_bbox
|
|
112
139
|
+ self.cost_gain["giou"] * cost_giou
|
|
113
140
|
)
|
|
114
|
-
|
|
141
|
+
|
|
142
|
+
# Add mask costs if available
|
|
115
143
|
if self.with_mask:
|
|
116
144
|
C += self._cost_mask(bs, gt_groups, masks, gt_mask)
|
|
117
145
|
|
|
118
|
-
# Set invalid values (NaNs and infinities) to 0
|
|
146
|
+
# Set invalid values (NaNs and infinities) to 0
|
|
119
147
|
C[C.isnan() | C.isinf()] = 0.0
|
|
120
148
|
|
|
121
149
|
C = C.view(bs, nq, -1).cpu()
|
|
@@ -158,28 +186,48 @@ class HungarianMatcher(nn.Module):
|
|
|
158
186
|
|
|
159
187
|
|
|
160
188
|
def get_cdn_group(
|
|
161
|
-
batch
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
189
|
+
batch: dict[str, Any],
|
|
190
|
+
num_classes: int,
|
|
191
|
+
num_queries: int,
|
|
192
|
+
class_embed: torch.Tensor,
|
|
193
|
+
num_dn: int = 100,
|
|
194
|
+
cls_noise_ratio: float = 0.5,
|
|
195
|
+
box_noise_scale: float = 1.0,
|
|
196
|
+
training: bool = False,
|
|
197
|
+
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
|
|
198
|
+
"""Generate contrastive denoising training group with positive and negative samples from ground truths.
|
|
199
|
+
|
|
200
|
+
This function creates denoising queries for contrastive denoising training by adding noise to ground truth bounding
|
|
201
|
+
boxes and class labels. It generates both positive and negative samples to improve model robustness.
|
|
165
202
|
|
|
166
203
|
Args:
|
|
167
|
-
batch (dict):
|
|
168
|
-
(torch.Tensor with shape (num_gts, 4)), 'gt_groups' (
|
|
169
|
-
|
|
170
|
-
num_classes (int):
|
|
171
|
-
num_queries (int): Number of queries.
|
|
172
|
-
class_embed (torch.Tensor):
|
|
173
|
-
num_dn (int
|
|
174
|
-
cls_noise_ratio (float
|
|
175
|
-
box_noise_scale (float
|
|
176
|
-
training (bool
|
|
204
|
+
batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)), 'gt_bboxes'
|
|
205
|
+
(torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of ground truths
|
|
206
|
+
per image.
|
|
207
|
+
num_classes (int): Total number of object classes.
|
|
208
|
+
num_queries (int): Number of object queries.
|
|
209
|
+
class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
|
|
210
|
+
num_dn (int): Number of denoising queries to generate.
|
|
211
|
+
cls_noise_ratio (float): Noise ratio for class labels.
|
|
212
|
+
box_noise_scale (float): Noise scale for bounding box coordinates.
|
|
213
|
+
training (bool): Whether model is in training mode.
|
|
177
214
|
|
|
178
215
|
Returns:
|
|
179
|
-
padding_cls (
|
|
180
|
-
padding_bbox (
|
|
181
|
-
attn_mask (
|
|
182
|
-
dn_meta (
|
|
216
|
+
padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).
|
|
217
|
+
padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).
|
|
218
|
+
attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).
|
|
219
|
+
dn_meta (dict[str, Any] | None): Meta information dictionary containing denoising parameters.
|
|
220
|
+
|
|
221
|
+
Examples:
|
|
222
|
+
Generate denoising group for training
|
|
223
|
+
>>> batch = {
|
|
224
|
+
... "cls": torch.tensor([0, 1, 2]),
|
|
225
|
+
... "bboxes": torch.rand(3, 4),
|
|
226
|
+
... "batch_idx": torch.tensor([0, 0, 1]),
|
|
227
|
+
... "gt_groups": [2, 1],
|
|
228
|
+
... }
|
|
229
|
+
>>> class_embed = torch.rand(80, 256) # 80 classes, 256 embedding dim
|
|
230
|
+
>>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)
|
|
183
231
|
"""
|
|
184
232
|
if (not training) or num_dn <= 0 or batch is None:
|
|
185
233
|
return None, None, None, None
|
|
@@ -197,7 +245,7 @@ def get_cdn_group(
|
|
|
197
245
|
gt_bbox = batch["bboxes"] # bs*num, 4
|
|
198
246
|
b_idx = batch["batch_idx"]
|
|
199
247
|
|
|
200
|
-
# Each group has positive and negative queries
|
|
248
|
+
# Each group has positive and negative queries
|
|
201
249
|
dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
|
|
202
250
|
dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
|
|
203
251
|
dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
|
|
@@ -207,10 +255,10 @@ def get_cdn_group(
|
|
|
207
255
|
neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
|
|
208
256
|
|
|
209
257
|
if cls_noise_ratio > 0:
|
|
210
|
-
#
|
|
258
|
+
# Apply class label noise to half of the samples
|
|
211
259
|
mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
|
|
212
260
|
idx = torch.nonzero(mask).squeeze(-1)
|
|
213
|
-
# Randomly
|
|
261
|
+
# Randomly assign new class labels
|
|
214
262
|
new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
|
|
215
263
|
dn_cls[idx] = new_label
|
|
216
264
|
|
|
@@ -229,7 +277,6 @@ def get_cdn_group(
|
|
|
229
277
|
dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid
|
|
230
278
|
|
|
231
279
|
num_dn = int(max_nums * 2 * num_group) # total denoising queries
|
|
232
|
-
# class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
|
|
233
280
|
dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
|
|
234
281
|
padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
|
|
235
282
|
padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
|
|
@@ -4,4 +4,4 @@ from ultralytics.models.yolo import classify, detect, obb, pose, segment, world,
|
|
|
4
4
|
|
|
5
5
|
from .model import YOLO, YOLOE, YOLOWorld
|
|
6
6
|
|
|
7
|
-
__all__ = "
|
|
7
|
+
__all__ = "YOLO", "YOLOE", "YOLOWorld", "classify", "detect", "obb", "pose", "segment", "world", "yoloe"
|
|
@@ -4,81 +4,83 @@ import cv2
|
|
|
4
4
|
import torch
|
|
5
5
|
from PIL import Image
|
|
6
6
|
|
|
7
|
+
from ultralytics.data.augment import classify_transforms
|
|
7
8
|
from ultralytics.engine.predictor import BasePredictor
|
|
8
9
|
from ultralytics.engine.results import Results
|
|
9
10
|
from ultralytics.utils import DEFAULT_CFG, ops
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
class ClassificationPredictor(BasePredictor):
|
|
13
|
-
"""
|
|
14
|
-
A class extending the BasePredictor class for prediction based on a classification model.
|
|
14
|
+
"""A class extending the BasePredictor class for prediction based on a classification model.
|
|
15
15
|
|
|
16
|
-
This predictor handles the specific requirements of classification models, including preprocessing images
|
|
17
|
-
|
|
16
|
+
This predictor handles the specific requirements of classification models, including preprocessing images and
|
|
17
|
+
postprocessing predictions to generate classification results.
|
|
18
18
|
|
|
19
19
|
Attributes:
|
|
20
20
|
args (dict): Configuration arguments for the predictor.
|
|
21
|
-
_legacy_transform_name (str): Name of the legacy transform class for backward compatibility.
|
|
22
21
|
|
|
23
22
|
Methods:
|
|
24
23
|
preprocess: Convert input images to model-compatible format.
|
|
25
24
|
postprocess: Process model predictions into Results objects.
|
|
26
25
|
|
|
27
|
-
Notes:
|
|
28
|
-
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
|
|
29
|
-
|
|
30
26
|
Examples:
|
|
31
27
|
>>> from ultralytics.utils import ASSETS
|
|
32
28
|
>>> from ultralytics.models.yolo.classify import ClassificationPredictor
|
|
33
29
|
>>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
|
|
34
30
|
>>> predictor = ClassificationPredictor(overrides=args)
|
|
35
31
|
>>> predictor.predict_cli()
|
|
32
|
+
|
|
33
|
+
Notes:
|
|
34
|
+
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
|
|
36
35
|
"""
|
|
37
36
|
|
|
38
37
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
39
|
-
"""
|
|
40
|
-
Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
|
|
38
|
+
"""Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
|
|
41
39
|
|
|
42
40
|
This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
|
|
43
41
|
tasks. It ensures the task is set to 'classify' regardless of input configuration.
|
|
44
42
|
|
|
45
43
|
Args:
|
|
46
|
-
cfg (dict): Default configuration dictionary containing prediction settings.
|
|
44
|
+
cfg (dict): Default configuration dictionary containing prediction settings.
|
|
47
45
|
overrides (dict, optional): Configuration overrides that take precedence over cfg.
|
|
48
46
|
_callbacks (list, optional): List of callback functions to be executed during prediction.
|
|
49
47
|
"""
|
|
50
48
|
super().__init__(cfg, overrides, _callbacks)
|
|
51
49
|
self.args.task = "classify"
|
|
52
|
-
|
|
50
|
+
|
|
51
|
+
def setup_source(self, source):
|
|
52
|
+
"""Set up source and inference mode and classify transforms."""
|
|
53
|
+
super().setup_source(source)
|
|
54
|
+
updated = (
|
|
55
|
+
self.model.model.transforms.transforms[0].size != max(self.imgsz)
|
|
56
|
+
if hasattr(self.model.model, "transforms") and hasattr(self.model.model.transforms.transforms[0], "size")
|
|
57
|
+
else False
|
|
58
|
+
)
|
|
59
|
+
self.transforms = (
|
|
60
|
+
classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms
|
|
61
|
+
)
|
|
53
62
|
|
|
54
63
|
def preprocess(self, img):
|
|
55
64
|
"""Convert input images to model-compatible tensor format with appropriate normalization."""
|
|
56
65
|
if not isinstance(img, torch.Tensor):
|
|
57
|
-
|
|
58
|
-
self.
|
|
66
|
+
img = torch.stack(
|
|
67
|
+
[self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
|
|
59
68
|
)
|
|
60
|
-
if is_legacy_transform: # to handle legacy transforms
|
|
61
|
-
img = torch.stack([self.transforms(im) for im in img], dim=0)
|
|
62
|
-
else:
|
|
63
|
-
img = torch.stack(
|
|
64
|
-
[self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
|
|
65
|
-
)
|
|
66
69
|
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
|
|
67
|
-
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
|
70
|
+
return img.half() if self.model.fp16 else img.float() # Convert uint8 to fp16/32
|
|
68
71
|
|
|
69
72
|
def postprocess(self, preds, img, orig_imgs):
|
|
70
|
-
"""
|
|
71
|
-
Process predictions to return Results objects with classification probabilities.
|
|
73
|
+
"""Process predictions to return Results objects with classification probabilities.
|
|
72
74
|
|
|
73
75
|
Args:
|
|
74
76
|
preds (torch.Tensor): Raw predictions from the model.
|
|
75
77
|
img (torch.Tensor): Input images after preprocessing.
|
|
76
|
-
orig_imgs (
|
|
78
|
+
orig_imgs (list[np.ndarray] | torch.Tensor): Original images before preprocessing.
|
|
77
79
|
|
|
78
80
|
Returns:
|
|
79
|
-
(
|
|
81
|
+
(list[Results]): List of Results objects containing classification results for each image.
|
|
80
82
|
"""
|
|
81
|
-
if not isinstance(orig_imgs, list): #
|
|
83
|
+
if not isinstance(orig_imgs, list): # Input images are a torch.Tensor, not a list
|
|
82
84
|
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
|
83
85
|
|
|
84
86
|
preds = preds[0] if isinstance(preds, (list, tuple)) else preds
|