dgenerate-ultralytics-headless 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/METADATA +2 -2
- dgenerate_ultralytics_headless-8.3.145.dist-info/RECORD +272 -0
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +52 -51
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +191 -161
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +4 -6
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +2 -2
- ultralytics/solutions/instance_segmentation.py +7 -4
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -11
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +189 -79
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +45 -29
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/top_level.txt +0 -0
ultralytics/utils/instance.py
CHANGED
@@ -11,7 +11,7 @@ from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy,
 
 
 def _ntuple(n):
-    """
+    """Create a function that converts input to n-tuple by repeating singleton values."""
 
     def parse(x):
         """Parse input to return n-tuple by repeating singleton values n times."""
@@ -33,16 +33,29 @@ __all__ = ("Bboxes", "Instances") # tuple or list
 
 class Bboxes:
     """
-    A class for handling bounding boxes.
+    A class for handling bounding boxes in multiple formats.
 
-    The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh'
-    Bounding box data should be provided
+    The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh' and provides methods for format
+    conversion, scaling, and area calculation. Bounding box data should be provided as numpy arrays.
 
     Attributes:
         bboxes (np.ndarray): The bounding boxes stored in a 2D numpy array with shape (N, 4).
         format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').
 
-
+    Methods:
+        convert: Convert bounding box format from one type to another.
+        areas: Calculate the area of bounding boxes.
+        mul: Multiply bounding box coordinates by scale factor(s).
+        add: Add offset to bounding box coordinates.
+        concatenate: Concatenate multiple Bboxes objects.
+
+    Examples:
+        Create bounding boxes in YOLO format
+        >>> bboxes = Bboxes(np.array([[100, 50, 150, 100]]), format="xywh")
+        >>> bboxes.convert("xyxy")
+        >>> print(bboxes.areas())
+
+    Notes:
         This class does not handle normalization or denormalization of bounding boxes.
     """
 
@@ -60,7 +73,6 @@ class Bboxes:
         assert bboxes.shape[1] == 4
         self.bboxes = bboxes
         self.format = format
-        # self.normalized = normalized
 
     def convert(self, format):
         """
@@ -82,36 +94,20 @@ class Bboxes:
         self.format = format
 
     def areas(self):
-        """
+        """Calculate the area of bounding boxes."""
         return (
             (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])  # format xyxy
             if self.format == "xyxy"
             else self.bboxes[:, 3] * self.bboxes[:, 2]  # format xywh or ltwh
         )
 
-    # def denormalize(self, w, h):
-    #     if not self.normalized:
-    #         return
-    #     assert (self.bboxes <= 1.0).all()
-    #     self.bboxes[:, 0::2] *= w
-    #     self.bboxes[:, 1::2] *= h
-    #     self.normalized = False
-    #
-    # def normalize(self, w, h):
-    #     if self.normalized:
-    #         return
-    #     assert (self.bboxes > 1.0).any()
-    #     self.bboxes[:, 0::2] /= w
-    #     self.bboxes[:, 1::2] /= h
-    #     self.normalized = True
-
     def mul(self, scale):
         """
         Multiply bounding box coordinates by scale factor(s).
 
         Args:
-            scale (int | tuple | list): Scale factor(s) for four coordinates.
-
+            scale (int | tuple | list): Scale factor(s) for four coordinates. If int, the same scale is applied to
+                all coordinates.
         """
         if isinstance(scale, Number):
             scale = to_4tuple(scale)
@@ -127,8 +123,8 @@ class Bboxes:
         Add offset to bounding box coordinates.
 
         Args:
-            offset (int | tuple | list): Offset(s) for four coordinates.
-
+            offset (int | tuple | list): Offset(s) for four coordinates. If int, the same offset is applied to
+                all coordinates.
         """
         if isinstance(offset, Number):
             offset = to_4tuple(offset)
@@ -140,7 +136,7 @@ class Bboxes:
         self.bboxes[:, 3] += offset[3]
 
     def __len__(self):
-        """Return the number of boxes."""
+        """Return the number of bounding boxes."""
         return len(self.bboxes)
 
     @classmethod
@@ -155,7 +151,7 @@ class Bboxes:
         Returns:
             (Bboxes): A new Bboxes object containing the concatenated bounding boxes.
 
-
+        Notes:
             The input should be a list or tuple of Bboxes objects.
         """
         assert isinstance(boxes_list, (list, tuple))
@@ -172,18 +168,14 @@ class Bboxes:
         Retrieve a specific bounding box or a set of bounding boxes using indexing.
 
         Args:
-            index (int | slice | np.ndarray): The index, slice, or boolean array to select
-                the desired bounding boxes.
+            index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired bounding boxes.
 
         Returns:
             (Bboxes): A new Bboxes object containing the selected bounding boxes.
 
-
-
-
-        Note:
-            When using boolean indexing, make sure to provide a boolean array with the same
-            length as the number of bounding boxes.
+        Notes:
+            When using boolean indexing, make sure to provide a boolean array with the same length as the number of
+            bounding boxes.
         """
         if isinstance(index, int):
             return Bboxes(self.bboxes[index].reshape(1, -1))
@@ -196,6 +188,10 @@ class Instances:
     """
     Container for bounding boxes, segments, and keypoints of detected objects in an image.
 
+    This class provides a unified interface for handling different types of object annotations including bounding
+    boxes, segmentation masks, and keypoints. It supports various operations like scaling, normalization, clipping,
+    and format conversion.
+
     Attributes:
         _bboxes (Bboxes): Internal object for handling bounding box operations.
         keypoints (np.ndarray): Keypoints with shape (N, 17, 3) in format (x, y, visible).
@@ -216,6 +212,7 @@ class Instances:
         concatenate: Concatenate multiple Instances objects.
 
     Examples:
+        Create instances with bounding boxes and segments
         >>> instances = Instances(
         ...     bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
         ...     segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
@@ -225,14 +222,14 @@ class Instances:
 
     def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
         """
-        Initialize the object with bounding boxes, segments, and keypoints.
+        Initialize the Instances object with bounding boxes, segments, and keypoints.
 
         Args:
-            bboxes (np.ndarray): Bounding boxes
+            bboxes (np.ndarray): Bounding boxes with shape (N, 4).
             segments (List | np.ndarray, optional): Segmentation masks.
-            keypoints (np.ndarray, optional): Keypoints
-            bbox_format (str
-            normalized (bool
+            keypoints (np.ndarray, optional): Keypoints with shape (N, 17, 3) in format (x, y, visible).
+            bbox_format (str): Format of bboxes.
+            normalized (bool): Whether the coordinates are normalized.
         """
         self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
         self.keypoints = keypoints
@@ -333,9 +330,9 @@ class Instances:
         Returns:
             (Instances): A new Instances object containing the selected boxes, segments, and keypoints if present.
 
-
-            When using boolean indexing, make sure to provide a boolean array with the same
-
+        Notes:
+            When using boolean indexing, make sure to provide a boolean array with the same length as the number of
+            instances.
         """
         segments = self.segments[index] if len(self.segments) else self.segments
         keypoints = self.keypoints[index] if self.keypoints is not None else None
@@ -442,7 +439,7 @@ class Instances:
         self.keypoints = keypoints
 
     def __len__(self):
-        """Return the
+        """Return the number of instances."""
         return len(self.bboxes)
 
     @classmethod
@@ -455,13 +452,12 @@ class Instances:
             axis (int, optional): The axis along which the arrays will be concatenated.
 
         Returns:
-            (Instances): A new Instances object containing the concatenated bounding boxes,
-
+            (Instances): A new Instances object containing the concatenated bounding boxes, segments, and keypoints
+                if present.
 
-
-            The `Instances` objects in the list should have the same properties, such as
-
-            coordinates are normalized.
+        Notes:
+            The `Instances` objects in the list should have the same properties, such as the format of the bounding
+            boxes, whether keypoints are present, and if the coordinates are normalized.
         """
         assert isinstance(instances_list, (list, tuple))
         if not instances_list:
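The expanded Bboxes docstring above now documents format conversion, area calculation, and coordinate scaling. A minimal usage sketch of that documented API (assuming the 8.3.145 wheel is installed; the box values are illustrative):

    import numpy as np
    from ultralytics.utils.instance import Bboxes

    # One box in "xywh" (center-x, center-y, width, height) format.
    boxes = Bboxes(np.array([[100.0, 50.0, 150.0, 100.0]]), format="xywh")
    boxes.convert("xyxy")  # in-place conversion to (x1, y1, x2, y2)
    print(boxes.bboxes)    # expected: [[ 25.   0. 175. 100.]]
    print(boxes.areas())   # expected: [15000.], width * height in any format
    boxes.mul(2)           # a scalar scale is broadcast to all four coordinates
    print(len(boxes))      # 1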
ultralytics/utils/loss.py
CHANGED
@@ -1,5 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from typing import Any, Dict, List, Tuple
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -17,20 +19,24 @@ class VarifocalLoss(nn.Module):
     """
     Varifocal loss by Zhang et al.
 
-
+    Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
+    hard-to-classify examples and balancing positive/negative samples.
 
-
+    Attributes:
         gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
         alpha (float): The balancing factor used to address class imbalance.
+
+    References:
+        https://arxiv.org/abs/2008.13367
     """
 
-    def __init__(self, gamma=2.0, alpha=0.75):
-        """Initialize the VarifocalLoss class."""
+    def __init__(self, gamma: float = 2.0, alpha: float = 0.75):
+        """Initialize the VarifocalLoss class with focusing and balancing parameters."""
         super().__init__()
         self.gamma = gamma
         self.alpha = alpha
 
-    def forward(self, pred_score, gt_score, label):
+    def forward(self, pred_score: torch.Tensor, gt_score: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
         """Compute varifocal loss between predictions and ground truth."""
         weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label
         with autocast(enabled=False):
@@ -46,18 +52,21 @@ class FocalLoss(nn.Module):
     """
     Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
 
-
+    Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing
+    on hard negatives during training.
+
+    Attributes:
         gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
-        alpha (
+        alpha (torch.Tensor): The balancing factor used to address class imbalance.
     """
 
-    def __init__(self, gamma=1.5, alpha=0.25):
-        """Initialize FocalLoss class with
+    def __init__(self, gamma: float = 1.5, alpha: float = 0.25):
+        """Initialize FocalLoss class with focusing and balancing parameters."""
         super().__init__()
         self.gamma = gamma
         self.alpha = torch.tensor(alpha)
 
-    def forward(self, pred, label):
+    def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
         """Calculate focal loss with modulating factors for class imbalance."""
         loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
         # p_t = torch.exp(-loss)
@@ -78,12 +87,12 @@ class FocalLoss(nn.Module):
 class DFLoss(nn.Module):
     """Criterion class for computing Distribution Focal Loss (DFL)."""
 
-    def __init__(self, reg_max=16) -> None:
+    def __init__(self, reg_max: int = 16) -> None:
         """Initialize the DFL module with regularization maximum."""
         super().__init__()
         self.reg_max = reg_max
 
-    def __call__(self, pred_dist, target):
+    def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391."""
         target = target.clamp_(0, self.reg_max - 1 - 0.01)
         tl = target.long()  # target left
@@ -99,12 +108,21 @@ class DFLoss(nn.Module):
 class BboxLoss(nn.Module):
     """Criterion class for computing training losses for bounding boxes."""
 
-    def __init__(self, reg_max=16):
+    def __init__(self, reg_max: int = 16):
         """Initialize the BboxLoss module with regularization maximum and DFL settings."""
         super().__init__()
         self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
 
-    def forward(
+    def forward(
+        self,
+        pred_dist: torch.Tensor,
+        pred_bboxes: torch.Tensor,
+        anchor_points: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        target_scores: torch.Tensor,
+        target_scores_sum: torch.Tensor,
+        fg_mask: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
         iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
@@ -124,11 +142,20 @@ class BboxLoss(nn.Module):
 class RotatedBboxLoss(BboxLoss):
     """Criterion class for computing training losses for rotated bounding boxes."""
 
-    def __init__(self, reg_max):
-        """Initialize the
+    def __init__(self, reg_max: int):
+        """Initialize the RotatedBboxLoss module with regularization maximum and DFL settings."""
         super().__init__(reg_max)
 
-    def forward(
+    def forward(
+        self,
+        pred_dist: torch.Tensor,
+        pred_bboxes: torch.Tensor,
+        anchor_points: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        target_scores: torch.Tensor,
+        target_scores_sum: torch.Tensor,
+        fg_mask: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for rotated bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
         iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
@@ -148,12 +175,14 @@ class RotatedBboxLoss(BboxLoss):
 class KeypointLoss(nn.Module):
     """Criterion class for computing keypoint losses."""
 
-    def __init__(self, sigmas) -> None:
+    def __init__(self, sigmas: torch.Tensor) -> None:
         """Initialize the KeypointLoss class with keypoint sigmas."""
         super().__init__()
         self.sigmas = sigmas
 
-    def forward(
+    def forward(
+        self, pred_kpts: torch.Tensor, gt_kpts: torch.Tensor, kpt_mask: torch.Tensor, area: torch.Tensor
+    ) -> torch.Tensor:
         """Calculate keypoint loss factor and Euclidean distance loss for keypoints."""
         d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
         kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
@@ -165,7 +194,7 @@ class KeypointLoss(nn.Module):
 class v8DetectionLoss:
     """Criterion class for computing training losses for YOLOv8 object detection."""
 
-    def __init__(self, model, tal_topk=10):  # model must be de-paralleled
+    def __init__(self, model, tal_topk: int = 10):  # model must be de-paralleled
         """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
         device = next(model.parameters()).device  # get model device
         h = model.args  # hyperparameters
@@ -185,7 +214,7 @@ class v8DetectionLoss:
         self.bbox_loss = BboxLoss(m.reg_max).to(device)
         self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
 
-    def preprocess(self, targets, batch_size, scale_tensor):
+    def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
         """Preprocess targets by converting to tensor format and scaling coordinates."""
         nl, ne = targets.shape
         if nl == 0:
@@ -202,7 +231,7 @@ class v8DetectionLoss:
             out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
         return out
 
-    def bbox_decode(self, anchor_points, pred_dist):
+    def bbox_decode(self, anchor_points: torch.Tensor, pred_dist: torch.Tensor) -> torch.Tensor:
         """Decode predicted object bounding box coordinates from anchor points and distribution."""
         if self.use_dfl:
             b, a, c = pred_dist.shape  # batch, anchors, channels
@@ -211,7 +240,7 @@ class v8DetectionLoss:
             # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
         return dist2bbox(pred_dist, anchor_points, xywh=False)
 
-    def __call__(self, preds, batch):
+    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
         feats = preds[1] if isinstance(preds, tuple) else preds
@@ -276,7 +305,7 @@ class v8SegmentationLoss(v8DetectionLoss):
         super().__init__(model)
         self.overlap = model.args.overlap_mask
 
-    def __call__(self, preds, batch):
+    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the combined loss for detection and segmentation."""
         loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
@@ -367,11 +396,11 @@ class v8SegmentationLoss(v8DetectionLoss):
         Compute the instance segmentation loss for a single image.
 
         Args:
-            gt_mask (torch.Tensor): Ground truth mask of shape (
-            pred (torch.Tensor): Predicted mask coefficients of shape (
+            gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
+            pred (torch.Tensor): Predicted mask coefficients of shape (N, 32).
             proto (torch.Tensor): Prototype masks of shape (32, H, W).
-            xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (
-            area (torch.Tensor): Area of each ground truth bounding box of shape (
+            xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (N, 4).
+            area (torch.Tensor): Area of each ground truth bounding box of shape (N,).
 
         Returns:
             (torch.Tensor): The calculated mask loss for a single image.
@@ -464,7 +493,7 @@ class v8PoseLoss(v8DetectionLoss):
         sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
         self.keypoint_loss = KeypointLoss(sigmas=sigmas)
 
-    def __call__(self, preds, batch):
+    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Calculate the total loss and detach it for pose estimation."""
         loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
         feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
@@ -531,7 +560,7 @@ class v8PoseLoss(v8DetectionLoss):
         return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
 
     @staticmethod
-    def kpts_decode(anchor_points, pred_kpts):
+    def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
         """Decode predicted keypoints to image coordinates."""
         y = pred_kpts.clone()
         y[..., :2] *= 2.0
@@ -540,8 +569,15 @@ class v8PoseLoss(v8DetectionLoss):
         return y
 
     def calculate_keypoints_loss(
-        self,
-
+        self,
+        masks: torch.Tensor,
+        target_gt_idx: torch.Tensor,
+        keypoints: torch.Tensor,
+        batch_idx: torch.Tensor,
+        stride_tensor: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        pred_kpts: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Calculate the keypoints loss for the model.
 
@@ -609,7 +645,7 @@ class v8PoseLoss(v8DetectionLoss):
 class v8ClassificationLoss:
     """Criterion class for computing training losses for classification."""
 
-    def __call__(self, preds, batch):
+    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute the classification loss between predictions and true labels."""
         preds = preds[1] if isinstance(preds, (list, tuple)) else preds
         loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
@@ -625,7 +661,7 @@ class v8OBBLoss(v8DetectionLoss):
         self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
         self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
 
-    def preprocess(self, targets, batch_size, scale_tensor):
+    def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
         """Preprocess targets for oriented bounding box detection."""
         if targets.shape[0] == 0:
             out = torch.zeros(batch_size, 0, 6, device=self.device)
@@ -642,7 +678,7 @@ class v8OBBLoss(v8DetectionLoss):
                 out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
         return out
 
-    def __call__(self, preds, batch):
+    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the loss for oriented bounding box detection."""
         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
         feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
@@ -714,7 +750,9 @@ class v8OBBLoss(v8DetectionLoss):
 
         return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
 
-    def bbox_decode(
+    def bbox_decode(
+        self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
+    ) -> torch.Tensor:
         """
         Decode predicted object bounding box coordinates from anchor points and distribution.
 
@@ -740,7 +778,7 @@ class E2EDetectLoss:
         self.one2many = v8DetectionLoss(model, tal_topk=10)
         self.one2one = v8DetectionLoss(model, tal_topk=1)
 
-    def __call__(self, preds, batch):
+    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
         preds = preds[1] if isinstance(preds, tuple) else preds
         one2many = preds["one2many"]
@@ -761,7 +799,7 @@ class TVPDetectLoss:
         self.ori_no = self.vp_criterion.no
         self.ori_reg_max = self.vp_criterion.reg_max
 
-    def __call__(self, preds, batch):
+    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt detection."""
         feats = preds[1] if isinstance(preds, tuple) else preds
         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
@@ -775,7 +813,7 @@ class TVPDetectLoss:
         box_loss = vp_loss[0][1]
         return box_loss, vp_loss[1]
 
-    def _get_vp_features(self, feats):
+    def _get_vp_features(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
         """Extract visual-prompt features from the model output."""
         vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
 
@@ -797,7 +835,7 @@ class TVPSegmentLoss(TVPDetectLoss):
         super().__init__(model)
         self.vp_criterion = v8SegmentationLoss(model)
 
-    def __call__(self, preds, batch):
+    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt segmentation."""
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it