dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
3
7
|
import torch
|
|
4
8
|
import torch.nn as nn
|
|
5
9
|
import torch.nn.functional as F
|
|
@@ -14,23 +18,26 @@ from .tal import bbox2dist
|
|
|
14
18
|
|
|
15
19
|
|
|
16
20
|
class VarifocalLoss(nn.Module):
|
|
17
|
-
"""
|
|
18
|
-
Varifocal loss by Zhang et al.
|
|
21
|
+
"""Varifocal loss by Zhang et al.
|
|
19
22
|
|
|
20
|
-
|
|
23
|
+
Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
|
|
24
|
+
hard-to-classify examples and balancing positive/negative samples.
|
|
21
25
|
|
|
22
|
-
|
|
26
|
+
Attributes:
|
|
23
27
|
gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
|
|
24
28
|
alpha (float): The balancing factor used to address class imbalance.
|
|
29
|
+
|
|
30
|
+
References:
|
|
31
|
+
https://arxiv.org/abs/2008.13367
|
|
25
32
|
"""
|
|
26
33
|
|
|
27
|
-
def __init__(self, gamma=2.0, alpha=0.75):
|
|
28
|
-
"""Initialize the VarifocalLoss class."""
|
|
34
|
+
def __init__(self, gamma: float = 2.0, alpha: float = 0.75):
|
|
35
|
+
"""Initialize the VarifocalLoss class with focusing and balancing parameters."""
|
|
29
36
|
super().__init__()
|
|
30
37
|
self.gamma = gamma
|
|
31
38
|
self.alpha = alpha
|
|
32
39
|
|
|
33
|
-
def forward(self, pred_score, gt_score, label):
|
|
40
|
+
def forward(self, pred_score: torch.Tensor, gt_score: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
|
|
34
41
|
"""Compute varifocal loss between predictions and ground truth."""
|
|
35
42
|
weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label
|
|
36
43
|
with autocast(enabled=False):
|
|
@@ -43,21 +50,23 @@ class VarifocalLoss(nn.Module):
|
|
|
43
50
|
|
|
44
51
|
|
|
45
52
|
class FocalLoss(nn.Module):
|
|
46
|
-
"""
|
|
47
|
-
Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
|
|
53
|
+
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
|
|
48
54
|
|
|
49
|
-
|
|
55
|
+
Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing on
|
|
56
|
+
hard negatives during training.
|
|
57
|
+
|
|
58
|
+
Attributes:
|
|
50
59
|
gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
|
|
51
|
-
alpha (
|
|
60
|
+
alpha (torch.Tensor): The balancing factor used to address class imbalance.
|
|
52
61
|
"""
|
|
53
62
|
|
|
54
|
-
def __init__(self, gamma=1.5, alpha=0.25):
|
|
55
|
-
"""Initialize FocalLoss class with
|
|
63
|
+
def __init__(self, gamma: float = 1.5, alpha: float = 0.25):
|
|
64
|
+
"""Initialize FocalLoss class with focusing and balancing parameters."""
|
|
56
65
|
super().__init__()
|
|
57
66
|
self.gamma = gamma
|
|
58
67
|
self.alpha = torch.tensor(alpha)
|
|
59
68
|
|
|
60
|
-
def forward(self, pred, label):
|
|
69
|
+
def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
|
|
61
70
|
"""Calculate focal loss with modulating factors for class imbalance."""
|
|
62
71
|
loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
|
|
63
72
|
# p_t = torch.exp(-loss)
|
|
@@ -78,12 +87,12 @@ class FocalLoss(nn.Module):
|
|
|
78
87
|
class DFLoss(nn.Module):
|
|
79
88
|
"""Criterion class for computing Distribution Focal Loss (DFL)."""
|
|
80
89
|
|
|
81
|
-
def __init__(self, reg_max=16) -> None:
|
|
90
|
+
def __init__(self, reg_max: int = 16) -> None:
|
|
82
91
|
"""Initialize the DFL module with regularization maximum."""
|
|
83
92
|
super().__init__()
|
|
84
93
|
self.reg_max = reg_max
|
|
85
94
|
|
|
86
|
-
def __call__(self, pred_dist, target):
|
|
95
|
+
def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|
87
96
|
"""Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391."""
|
|
88
97
|
target = target.clamp_(0, self.reg_max - 1 - 0.01)
|
|
89
98
|
tl = target.long() # target left
|
|
@@ -99,12 +108,21 @@ class DFLoss(nn.Module):
|
|
|
99
108
|
class BboxLoss(nn.Module):
|
|
100
109
|
"""Criterion class for computing training losses for bounding boxes."""
|
|
101
110
|
|
|
102
|
-
def __init__(self, reg_max=16):
|
|
111
|
+
def __init__(self, reg_max: int = 16):
|
|
103
112
|
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
|
|
104
113
|
super().__init__()
|
|
105
114
|
self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
|
|
106
115
|
|
|
107
|
-
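def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):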
|
|
116
|
+
def forward(
|
|
117
|
+
self,
|
|
118
|
+
pred_dist: torch.Tensor,
|
|
119
|
+
pred_bboxes: torch.Tensor,
|
|
120
|
+
anchor_points: torch.Tensor,
|
|
121
|
+
target_bboxes: torch.Tensor,
|
|
122
|
+
target_scores: torch.Tensor,
|
|
123
|
+
target_scores_sum: torch.Tensor,
|
|
124
|
+
fg_mask: torch.Tensor,
|
|
125
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
108
126
|
"""Compute IoU and DFL losses for bounding boxes."""
|
|
109
127
|
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
|
110
128
|
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
|
|
@@ -124,11 +142,20 @@ class BboxLoss(nn.Module):
|
|
|
124
142
|
class RotatedBboxLoss(BboxLoss):
|
|
125
143
|
"""Criterion class for computing training losses for rotated bounding boxes."""
|
|
126
144
|
|
|
127
|
-
def __init__(self, reg_max):
|
|
128
|
-
"""Initialize the
|
|
145
|
+
def __init__(self, reg_max: int):
|
|
146
|
+
"""Initialize the RotatedBboxLoss module with regularization maximum and DFL settings."""
|
|
129
147
|
super().__init__(reg_max)
|
|
130
148
|
|
|
131
|
-
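def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):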
|
|
149
|
+
def forward(
|
|
150
|
+
self,
|
|
151
|
+
pred_dist: torch.Tensor,
|
|
152
|
+
pred_bboxes: torch.Tensor,
|
|
153
|
+
anchor_points: torch.Tensor,
|
|
154
|
+
target_bboxes: torch.Tensor,
|
|
155
|
+
target_scores: torch.Tensor,
|
|
156
|
+
target_scores_sum: torch.Tensor,
|
|
157
|
+
fg_mask: torch.Tensor,
|
|
158
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
132
159
|
"""Compute IoU and DFL losses for rotated bounding boxes."""
|
|
133
160
|
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
|
134
161
|
iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
|
|
@@ -148,12 +175,14 @@ class RotatedBboxLoss(BboxLoss):
|
|
|
148
175
|
class KeypointLoss(nn.Module):
|
|
149
176
|
"""Criterion class for computing keypoint losses."""
|
|
150
177
|
|
|
151
|
-
def __init__(self, sigmas) -> None:
|
|
178
|
+
def __init__(self, sigmas: torch.Tensor) -> None:
|
|
152
179
|
"""Initialize the KeypointLoss class with keypoint sigmas."""
|
|
153
180
|
super().__init__()
|
|
154
181
|
self.sigmas = sigmas
|
|
155
182
|
|
|
156
|
-
def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
|
|
183
|
+
def forward(
|
|
184
|
+
self, pred_kpts: torch.Tensor, gt_kpts: torch.Tensor, kpt_mask: torch.Tensor, area: torch.Tensor
|
|
185
|
+
) -> torch.Tensor:
|
|
157
186
|
"""Calculate keypoint loss factor and Euclidean distance loss for keypoints."""
|
|
158
187
|
d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
|
|
159
188
|
kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
|
|
@@ -165,7 +194,7 @@ class KeypointLoss(nn.Module):
|
|
|
165
194
|
class v8DetectionLoss:
|
|
166
195
|
"""Criterion class for computing training losses for YOLOv8 object detection."""
|
|
167
196
|
|
|
168
|
-
def __init__(self, model, tal_topk=10): # model must be de-paralleled
|
|
197
|
+
def __init__(self, model, tal_topk: int = 10): # model must be de-paralleled
|
|
169
198
|
"""Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
|
|
170
199
|
device = next(model.parameters()).device # get model device
|
|
171
200
|
h = model.args # hyperparameters
|
|
@@ -185,7 +214,7 @@ class v8DetectionLoss:
|
|
|
185
214
|
self.bbox_loss = BboxLoss(m.reg_max).to(device)
|
|
186
215
|
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
|
|
187
216
|
|
|
188
|
-
def preprocess(self, targets, batch_size, scale_tensor):
|
|
217
|
+
def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
|
|
189
218
|
"""Preprocess targets by converting to tensor format and scaling coordinates."""
|
|
190
219
|
nl, ne = targets.shape
|
|
191
220
|
if nl == 0:
|
|
@@ -202,7 +231,7 @@ class v8DetectionLoss:
|
|
|
202
231
|
out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
|
|
203
232
|
return out
|
|
204
233
|
|
|
205
|
-
def bbox_decode(self, anchor_points, pred_dist):
|
|
234
|
+
def bbox_decode(self, anchor_points: torch.Tensor, pred_dist: torch.Tensor) -> torch.Tensor:
|
|
206
235
|
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
|
|
207
236
|
if self.use_dfl:
|
|
208
237
|
b, a, c = pred_dist.shape # batch, anchors, channels
|
|
@@ -211,7 +240,7 @@ class v8DetectionLoss:
|
|
|
211
240
|
# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
|
|
212
241
|
return dist2bbox(pred_dist, anchor_points, xywh=False)
|
|
213
242
|
|
|
214
|
-
def __call__(self, preds, batch):
|
|
243
|
+
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
215
244
|
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
|
|
216
245
|
loss = torch.zeros(3, device=self.device) # box, cls, dfl
|
|
217
246
|
feats = preds[1] if isinstance(preds, tuple) else preds
|
|
@@ -229,7 +258,7 @@ class v8DetectionLoss:
|
|
|
229
258
|
|
|
230
259
|
# Targets
|
|
231
260
|
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
|
232
|
-
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
261
|
+
targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
233
262
|
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
|
234
263
|
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
|
|
235
264
|
|
|
@@ -256,9 +285,14 @@ class v8DetectionLoss:
|
|
|
256
285
|
|
|
257
286
|
# Bbox loss
|
|
258
287
|
if fg_mask.sum():
|
|
259
|
-
target_bboxes /= stride_tensor
|
|
260
288
|
loss[0], loss[2] = self.bbox_loss(
|
|
261
|
-
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
|
|
289
|
+
pred_distri,
|
|
290
|
+
pred_bboxes,
|
|
291
|
+
anchor_points,
|
|
292
|
+
target_bboxes / stride_tensor,
|
|
293
|
+
target_scores,
|
|
294
|
+
target_scores_sum,
|
|
295
|
+
fg_mask,
|
|
262
296
|
)
|
|
263
297
|
|
|
264
298
|
loss[0] *= self.hyp.box # box gain
|
|
@@ -276,7 +310,7 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|
|
276
310
|
super().__init__(model)
|
|
277
311
|
self.overlap = model.args.overlap_mask
|
|
278
312
|
|
|
279
|
-
def __call__(self, preds, batch):
|
|
313
|
+
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
280
314
|
"""Calculate and return the combined loss for detection and segmentation."""
|
|
281
315
|
loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl
|
|
282
316
|
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
|
|
@@ -298,7 +332,7 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|
|
298
332
|
try:
|
|
299
333
|
batch_idx = batch["batch_idx"].view(-1, 1)
|
|
300
334
|
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
|
301
|
-
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
335
|
+
targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
302
336
|
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
|
303
337
|
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
|
|
304
338
|
except RuntimeError as e:
|
|
@@ -357,21 +391,20 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|
|
357
391
|
loss[2] *= self.hyp.cls # cls gain
|
|
358
392
|
loss[3] *= self.hyp.dfl # dfl gain
|
|
359
393
|
|
|
360
|
-
return loss * batch_size, loss.detach() # loss(box, cls, dfl)
|
|
394
|
+
return loss * batch_size, loss.detach() # loss(box, seg, cls, dfl)
|
|
361
395
|
|
|
362
396
|
@staticmethod
|
|
363
397
|
def single_mask_loss(
|
|
364
398
|
gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
|
|
365
399
|
) -> torch.Tensor:
|
|
366
|
-
"""
|
|
367
|
-
Compute the instance segmentation loss for a single image.
|
|
400
|
+
"""Compute the instance segmentation loss for a single image.
|
|
368
401
|
|
|
369
402
|
Args:
|
|
370
|
-
gt_mask (torch.Tensor): Ground truth mask of shape (
|
|
371
|
-
pred (torch.Tensor): Predicted mask coefficients of shape (
|
|
403
|
+
gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
|
|
404
|
+
pred (torch.Tensor): Predicted mask coefficients of shape (N, 32).
|
|
372
405
|
proto (torch.Tensor): Prototype masks of shape (32, H, W).
|
|
373
|
-
xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (
|
|
374
|
-
area (torch.Tensor): Area of each ground truth bounding box of shape (
|
|
406
|
+
xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (N, 4).
|
|
407
|
+
area (torch.Tensor): Area of each ground truth bounding box of shape (N,).
|
|
375
408
|
|
|
376
409
|
Returns:
|
|
377
410
|
(torch.Tensor): The calculated mask loss for a single image.
|
|
@@ -396,8 +429,7 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|
|
396
429
|
imgsz: torch.Tensor,
|
|
397
430
|
overlap: bool,
|
|
398
431
|
) -> torch.Tensor:
|
|
399
|
-
"""
|
|
400
|
-
Calculate the loss for instance segmentation.
|
|
432
|
+
"""Calculate the loss for instance segmentation.
|
|
401
433
|
|
|
402
434
|
Args:
|
|
403
435
|
fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
|
|
@@ -464,7 +496,7 @@ class v8PoseLoss(v8DetectionLoss):
|
|
|
464
496
|
sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
|
|
465
497
|
self.keypoint_loss = KeypointLoss(sigmas=sigmas)
|
|
466
498
|
|
|
467
|
-
def __call__(self, preds, batch):
|
|
499
|
+
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
468
500
|
"""Calculate the total loss and detach it for pose estimation."""
|
|
469
501
|
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
|
|
470
502
|
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
|
|
@@ -485,7 +517,7 @@ class v8PoseLoss(v8DetectionLoss):
|
|
|
485
517
|
batch_size = pred_scores.shape[0]
|
|
486
518
|
batch_idx = batch["batch_idx"].view(-1, 1)
|
|
487
519
|
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
|
488
|
-
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
520
|
+
targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
489
521
|
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
|
490
522
|
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
|
|
491
523
|
|
|
@@ -531,7 +563,7 @@ class v8PoseLoss(v8DetectionLoss):
|
|
|
531
563
|
return loss * batch_size, loss.detach() # loss(box, cls, dfl)
|
|
532
564
|
|
|
533
565
|
@staticmethod
|
|
534
|
-
def kpts_decode(anchor_points, pred_kpts):
|
|
566
|
+
def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
|
|
535
567
|
"""Decode predicted keypoints to image coordinates."""
|
|
536
568
|
y = pred_kpts.clone()
|
|
537
569
|
y[..., :2] *= 2.0
|
|
@@ -540,10 +572,16 @@ class v8PoseLoss(v8DetectionLoss):
|
|
|
540
572
|
return y
|
|
541
573
|
|
|
542
574
|
def calculate_keypoints_loss(
|
|
543
|
-
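self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
|
|
544
|
-
):
|
|
545
|
-
"""
|
|
546
|
-
Calculate the keypoints loss for the model.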
|
|
575
|
+
self,
|
|
576
|
+
masks: torch.Tensor,
|
|
577
|
+
target_gt_idx: torch.Tensor,
|
|
578
|
+
keypoints: torch.Tensor,
|
|
579
|
+
batch_idx: torch.Tensor,
|
|
580
|
+
stride_tensor: torch.Tensor,
|
|
581
|
+
target_bboxes: torch.Tensor,
|
|
582
|
+
pred_kpts: torch.Tensor,
|
|
583
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
584
|
+
"""Calculate the keypoints loss for the model.
|
|
547
585
|
|
|
548
586
|
This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
|
|
549
587
|
based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
|
|
@@ -609,12 +647,11 @@ class v8PoseLoss(v8DetectionLoss):
|
|
|
609
647
|
class v8ClassificationLoss:
|
|
610
648
|
"""Criterion class for computing training losses for classification."""
|
|
611
649
|
|
|
612
|
-
def __call__(self, preds, batch):
|
|
650
|
+
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
613
651
|
"""Compute the classification loss between predictions and true labels."""
|
|
614
652
|
preds = preds[1] if isinstance(preds, (list, tuple)) else preds
|
|
615
653
|
loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
|
|
616
|
-
loss_items = loss.detach()
|
|
617
|
-
return loss, loss_items
|
|
654
|
+
return loss, loss.detach()
|
|
618
655
|
|
|
619
656
|
|
|
620
657
|
class v8OBBLoss(v8DetectionLoss):
|
|
@@ -626,7 +663,7 @@ class v8OBBLoss(v8DetectionLoss):
|
|
|
626
663
|
self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
|
|
627
664
|
self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
|
|
628
665
|
|
|
629
|
-
def preprocess(self, targets, batch_size, scale_tensor):
|
|
666
|
+
def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
|
|
630
667
|
"""Preprocess targets for oriented bounding box detection."""
|
|
631
668
|
if targets.shape[0] == 0:
|
|
632
669
|
out = torch.zeros(batch_size, 0, 6, device=self.device)
|
|
@@ -643,7 +680,7 @@ class v8OBBLoss(v8DetectionLoss):
|
|
|
643
680
|
out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
|
|
644
681
|
return out
|
|
645
682
|
|
|
646
|
-
def __call__(self, preds, batch):
|
|
683
|
+
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
647
684
|
"""Calculate and return the loss for oriented bounding box detection."""
|
|
648
685
|
loss = torch.zeros(3, device=self.device) # box, cls, dfl
|
|
649
686
|
feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
|
|
@@ -667,7 +704,7 @@ class v8OBBLoss(v8DetectionLoss):
|
|
|
667
704
|
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
|
|
668
705
|
rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
|
|
669
706
|
targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
|
|
670
|
-
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
707
|
+
targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
|
671
708
|
gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
|
|
672
709
|
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
|
|
673
710
|
except RuntimeError as e:
|
|
@@ -715,9 +752,10 @@ class v8OBBLoss(v8DetectionLoss):
|
|
|
715
752
|
|
|
716
753
|
return loss * batch_size, loss.detach() # loss(box, cls, dfl)
|
|
717
754
|
|
|
718
|
-
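def bbox_decode(self, anchor_points, pred_dist, pred_angle):
|
|
719
|
-
"""
|
|
720
|
-
Decode predicted object bounding box coordinates from anchor points and distribution.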
|
|
755
|
+
def bbox_decode(
|
|
756
|
+
self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
|
|
757
|
+
) -> torch.Tensor:
|
|
758
|
+
"""Decode predicted object bounding box coordinates from anchor points and distribution.
|
|
721
759
|
|
|
722
760
|
Args:
|
|
723
761
|
anchor_points (torch.Tensor): Anchor points, (h*w, 2).
|
|
@@ -741,7 +779,7 @@ class E2EDetectLoss:
|
|
|
741
779
|
self.one2many = v8DetectionLoss(model, tal_topk=10)
|
|
742
780
|
self.one2one = v8DetectionLoss(model, tal_topk=1)
|
|
743
781
|
|
|
744
|
-
def __call__(self, preds, batch):
|
|
782
|
+
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
745
783
|
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
|
|
746
784
|
preds = preds[1] if isinstance(preds, tuple) else preds
|
|
747
785
|
one2many = preds["one2many"]
|
|
@@ -762,7 +800,7 @@ class TVPDetectLoss:
|
|
|
762
800
|
self.ori_no = self.vp_criterion.no
|
|
763
801
|
self.ori_reg_max = self.vp_criterion.reg_max
|
|
764
802
|
|
|
765
|
-
def __call__(self, preds, batch):
|
|
803
|
+
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
766
804
|
"""Calculate the loss for text-visual prompt detection."""
|
|
767
805
|
feats = preds[1] if isinstance(preds, tuple) else preds
|
|
768
806
|
assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
|
|
@@ -776,7 +814,7 @@ class TVPDetectLoss:
|
|
|
776
814
|
box_loss = vp_loss[0][1]
|
|
777
815
|
return box_loss, vp_loss[1]
|
|
778
816
|
|
|
779
|
-
def _get_vp_features(self, feats):
|
|
817
|
+
def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
|
|
780
818
|
"""Extract visual-prompt features from the model output."""
|
|
781
819
|
vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
|
|
782
820
|
|
|
@@ -798,7 +836,7 @@ class TVPSegmentLoss(TVPDetectLoss):
|
|
|
798
836
|
super().__init__(model)
|
|
799
837
|
self.vp_criterion = v8SegmentationLoss(model)
|
|
800
838
|
|
|
801
|
-
def __call__(self, preds, batch):
|
|
839
|
+
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
802
840
|
"""Calculate the loss for text-visual prompt segmentation."""
|
|
803
841
|
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
|
|
804
842
|
assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
|