ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +3 -1
- ultralytics/data/explorer/explorer.py +120 -100
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +123 -89
- ultralytics/data/explorer/utils.py +37 -39
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +61 -41
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -11
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +70 -59
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +60 -52
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +17 -14
- ultralytics/solutions/heatmap.py +57 -55
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -152
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +38 -28
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.238.dist-info/RECORD +0 -188
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/models/utils/loss.py
CHANGED
|
@@ -30,14 +30,9 @@ class DETRLoss(nn.Module):
|
|
|
30
30
|
device (torch.device): Device on which tensors are stored.
|
|
31
31
|
"""
|
|
32
32
|
|
|
33
|
-
def __init__(
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
aux_loss=True,
|
|
37
|
-
use_fl=True,
|
|
38
|
-
use_vfl=False,
|
|
39
|
-
use_uni_match=False,
|
|
40
|
-
uni_match_ind=0):
|
|
33
|
+
def __init__(
|
|
34
|
+
self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
|
|
35
|
+
):
|
|
41
36
|
"""
|
|
42
37
|
DETR loss function.
|
|
43
38
|
|
|
@@ -52,9 +47,9 @@ class DETRLoss(nn.Module):
|
|
|
52
47
|
super().__init__()
|
|
53
48
|
|
|
54
49
|
if loss_gain is None:
|
|
55
|
-
loss_gain = {
|
|
50
|
+
loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
|
|
56
51
|
self.nc = nc
|
|
57
|
-
self.matcher = HungarianMatcher(cost_gain={
|
|
52
|
+
self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
|
|
58
53
|
self.loss_gain = loss_gain
|
|
59
54
|
self.aux_loss = aux_loss
|
|
60
55
|
self.fl = FocalLoss() if use_fl else None
|
|
@@ -64,10 +59,10 @@ class DETRLoss(nn.Module):
|
|
|
64
59
|
self.uni_match_ind = uni_match_ind
|
|
65
60
|
self.device = None
|
|
66
61
|
|
|
67
|
-
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=
|
|
62
|
+
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
|
|
68
63
|
"""Computes the classification loss based on predictions, target values, and ground truth scores."""
|
|
69
64
|
# Logits: [b, query, num_classes], gt_class: list[[n, 1]]
|
|
70
|
-
name_class = f
|
|
65
|
+
name_class = f"loss_class{postfix}"
|
|
71
66
|
bs, nq = pred_scores.shape[:2]
|
|
72
67
|
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
|
|
73
68
|
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
|
|
@@ -82,28 +77,28 @@ class DETRLoss(nn.Module):
|
|
|
82
77
|
loss_cls = self.fl(pred_scores, one_hot.float())
|
|
83
78
|
loss_cls /= max(num_gts, 1) / nq
|
|
84
79
|
else:
|
|
85
|
-
loss_cls = nn.BCEWithLogitsLoss(reduction=
|
|
80
|
+
loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
|
|
86
81
|
|
|
87
|
-
return {name_class: loss_cls.squeeze() * self.loss_gain[
|
|
82
|
+
return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
|
|
88
83
|
|
|
89
|
-
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=
|
|
84
|
+
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
|
|
90
85
|
"""Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
|
|
91
86
|
boxes.
|
|
92
87
|
"""
|
|
93
88
|
# Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
|
|
94
|
-
name_bbox = f
|
|
95
|
-
name_giou = f
|
|
89
|
+
name_bbox = f"loss_bbox{postfix}"
|
|
90
|
+
name_giou = f"loss_giou{postfix}"
|
|
96
91
|
|
|
97
92
|
loss = {}
|
|
98
93
|
if len(gt_bboxes) == 0:
|
|
99
|
-
loss[name_bbox] = torch.tensor(0
|
|
100
|
-
loss[name_giou] = torch.tensor(0
|
|
94
|
+
loss[name_bbox] = torch.tensor(0.0, device=self.device)
|
|
95
|
+
loss[name_giou] = torch.tensor(0.0, device=self.device)
|
|
101
96
|
return loss
|
|
102
97
|
|
|
103
|
-
loss[name_bbox] = self.loss_gain[
|
|
98
|
+
loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
|
|
104
99
|
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
|
|
105
100
|
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
|
|
106
|
-
loss[name_giou] = self.loss_gain[
|
|
101
|
+
loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
|
|
107
102
|
return {k: v.squeeze() for k, v in loss.items()}
|
|
108
103
|
|
|
109
104
|
# This function is for future RT-DETR Segment models
|
|
@@ -137,50 +132,57 @@ class DETRLoss(nn.Module):
|
|
|
137
132
|
# loss = 1 - (numerator + 1) / (denominator + 1)
|
|
138
133
|
# return loss.sum() / num_gts
|
|
139
134
|
|
|
140
|
-
def _get_loss_aux(
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
135
|
+
def _get_loss_aux(
|
|
136
|
+
self,
|
|
137
|
+
pred_bboxes,
|
|
138
|
+
pred_scores,
|
|
139
|
+
gt_bboxes,
|
|
140
|
+
gt_cls,
|
|
141
|
+
gt_groups,
|
|
142
|
+
match_indices=None,
|
|
143
|
+
postfix="",
|
|
144
|
+
masks=None,
|
|
145
|
+
gt_mask=None,
|
|
146
|
+
):
|
|
150
147
|
"""Get auxiliary losses."""
|
|
151
148
|
# NOTE: loss class, bbox, giou, mask, dice
|
|
152
149
|
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
|
|
153
150
|
if match_indices is None and self.use_uni_match:
|
|
154
|
-
match_indices = self.matcher(
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
151
|
+
match_indices = self.matcher(
|
|
152
|
+
pred_bboxes[self.uni_match_ind],
|
|
153
|
+
pred_scores[self.uni_match_ind],
|
|
154
|
+
gt_bboxes,
|
|
155
|
+
gt_cls,
|
|
156
|
+
gt_groups,
|
|
157
|
+
masks=masks[self.uni_match_ind] if masks is not None else None,
|
|
158
|
+
gt_mask=gt_mask,
|
|
159
|
+
)
|
|
161
160
|
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
|
|
162
161
|
aux_masks = masks[i] if masks is not None else None
|
|
163
|
-
loss_ = self._get_loss(
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
loss[
|
|
162
|
+
loss_ = self._get_loss(
|
|
163
|
+
aux_bboxes,
|
|
164
|
+
aux_scores,
|
|
165
|
+
gt_bboxes,
|
|
166
|
+
gt_cls,
|
|
167
|
+
gt_groups,
|
|
168
|
+
masks=aux_masks,
|
|
169
|
+
gt_mask=gt_mask,
|
|
170
|
+
postfix=postfix,
|
|
171
|
+
match_indices=match_indices,
|
|
172
|
+
)
|
|
173
|
+
loss[0] += loss_[f"loss_class{postfix}"]
|
|
174
|
+
loss[1] += loss_[f"loss_bbox{postfix}"]
|
|
175
|
+
loss[2] += loss_[f"loss_giou{postfix}"]
|
|
175
176
|
# if masks is not None and gt_mask is not None:
|
|
176
177
|
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
|
|
177
178
|
# loss[3] += loss_[f'loss_mask{postfix}']
|
|
178
179
|
# loss[4] += loss_[f'loss_dice{postfix}']
|
|
179
180
|
|
|
180
181
|
loss = {
|
|
181
|
-
f
|
|
182
|
-
f
|
|
183
|
-
f
|
|
182
|
+
f"loss_class_aux{postfix}": loss[0],
|
|
183
|
+
f"loss_bbox_aux{postfix}": loss[1],
|
|
184
|
+
f"loss_giou_aux{postfix}": loss[2],
|
|
185
|
+
}
|
|
184
186
|
# if masks is not None and gt_mask is not None:
|
|
185
187
|
# loss[f'loss_mask_aux{postfix}'] = loss[3]
|
|
186
188
|
# loss[f'loss_dice_aux{postfix}'] = loss[4]
|
|
@@ -196,33 +198,37 @@ class DETRLoss(nn.Module):
|
|
|
196
198
|
|
|
197
199
|
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
|
|
198
200
|
"""Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
|
|
199
|
-
pred_assigned = torch.cat(
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
201
|
+
pred_assigned = torch.cat(
|
|
202
|
+
[
|
|
203
|
+
t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
|
204
|
+
for t, (I, _) in zip(pred_bboxes, match_indices)
|
|
205
|
+
]
|
|
206
|
+
)
|
|
207
|
+
gt_assigned = torch.cat(
|
|
208
|
+
[
|
|
209
|
+
t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
|
210
|
+
for t, (_, J) in zip(gt_bboxes, match_indices)
|
|
211
|
+
]
|
|
212
|
+
)
|
|
205
213
|
return pred_assigned, gt_assigned
|
|
206
214
|
|
|
207
|
-
def _get_loss(
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
215
|
+
def _get_loss(
|
|
216
|
+
self,
|
|
217
|
+
pred_bboxes,
|
|
218
|
+
pred_scores,
|
|
219
|
+
gt_bboxes,
|
|
220
|
+
gt_cls,
|
|
221
|
+
gt_groups,
|
|
222
|
+
masks=None,
|
|
223
|
+
gt_mask=None,
|
|
224
|
+
postfix="",
|
|
225
|
+
match_indices=None,
|
|
226
|
+
):
|
|
217
227
|
"""Get losses."""
|
|
218
228
|
if match_indices is None:
|
|
219
|
-
match_indices = self.matcher(
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
gt_cls,
|
|
223
|
-
gt_groups,
|
|
224
|
-
masks=masks,
|
|
225
|
-
gt_mask=gt_mask)
|
|
229
|
+
match_indices = self.matcher(
|
|
230
|
+
pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
|
|
231
|
+
)
|
|
226
232
|
|
|
227
233
|
idx, gt_idx = self._get_index(match_indices)
|
|
228
234
|
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
|
|
@@ -242,7 +248,7 @@ class DETRLoss(nn.Module):
|
|
|
242
248
|
# loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
|
|
243
249
|
return loss
|
|
244
250
|
|
|
245
|
-
def forward(self, pred_bboxes, pred_scores, batch, postfix=
|
|
251
|
+
def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
|
|
246
252
|
"""
|
|
247
253
|
Args:
|
|
248
254
|
pred_bboxes (torch.Tensor): [l, b, query, 4]
|
|
@@ -254,21 +260,19 @@ class DETRLoss(nn.Module):
|
|
|
254
260
|
postfix (str): postfix of loss name.
|
|
255
261
|
"""
|
|
256
262
|
self.device = pred_bboxes.device
|
|
257
|
-
match_indices = kwargs.get(
|
|
258
|
-
gt_cls, gt_bboxes, gt_groups = batch[
|
|
263
|
+
match_indices = kwargs.get("match_indices", None)
|
|
264
|
+
gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]
|
|
259
265
|
|
|
260
|
-
total_loss = self._get_loss(
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
gt_cls,
|
|
264
|
-
gt_groups,
|
|
265
|
-
postfix=postfix,
|
|
266
|
-
match_indices=match_indices)
|
|
266
|
+
total_loss = self._get_loss(
|
|
267
|
+
pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
|
|
268
|
+
)
|
|
267
269
|
|
|
268
270
|
if self.aux_loss:
|
|
269
271
|
total_loss.update(
|
|
270
|
-
self._get_loss_aux(
|
|
271
|
-
|
|
272
|
+
self._get_loss_aux(
|
|
273
|
+
pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
|
|
274
|
+
)
|
|
275
|
+
)
|
|
272
276
|
|
|
273
277
|
return total_loss
|
|
274
278
|
|
|
@@ -300,18 +304,18 @@ class RTDETRDetectionLoss(DETRLoss):
|
|
|
300
304
|
|
|
301
305
|
# Check for denoising metadata to compute denoising training loss
|
|
302
306
|
if dn_meta is not None:
|
|
303
|
-
dn_pos_idx, dn_num_group = dn_meta[
|
|
304
|
-
assert len(batch[
|
|
307
|
+
dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
|
|
308
|
+
assert len(batch["gt_groups"]) == len(dn_pos_idx)
|
|
305
309
|
|
|
306
310
|
# Get the match indices for denoising
|
|
307
|
-
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch[
|
|
311
|
+
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])
|
|
308
312
|
|
|
309
313
|
# Compute the denoising training loss
|
|
310
|
-
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix=
|
|
314
|
+
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
|
|
311
315
|
total_loss.update(dn_loss)
|
|
312
316
|
else:
|
|
313
317
|
# If no denoising metadata is provided, set denoising loss to zero
|
|
314
|
-
total_loss.update({f
|
|
318
|
+
total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})
|
|
315
319
|
|
|
316
320
|
return total_loss
|
|
317
321
|
|
|
@@ -334,8 +338,8 @@ class RTDETRDetectionLoss(DETRLoss):
|
|
|
334
338
|
if num_gt > 0:
|
|
335
339
|
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
|
|
336
340
|
gt_idx = gt_idx.repeat(dn_num_group)
|
|
337
|
-
assert len(dn_pos_idx[i]) == len(gt_idx),
|
|
338
|
-
f
|
|
341
|
+
assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, "
|
|
342
|
+
f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
|
|
339
343
|
dn_match_indices.append((dn_pos_idx[i], gt_idx))
|
|
340
344
|
else:
|
|
341
345
|
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
|
ultralytics/models/utils/ops.py
CHANGED
|
@@ -37,7 +37,7 @@ class HungarianMatcher(nn.Module):
|
|
|
37
37
|
"""
|
|
38
38
|
super().__init__()
|
|
39
39
|
if cost_gain is None:
|
|
40
|
-
cost_gain = {
|
|
40
|
+
cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
|
|
41
41
|
self.cost_gain = cost_gain
|
|
42
42
|
self.use_fl = use_fl
|
|
43
43
|
self.with_mask = with_mask
|
|
@@ -86,7 +86,7 @@ class HungarianMatcher(nn.Module):
|
|
|
86
86
|
# Compute the classification cost
|
|
87
87
|
pred_scores = pred_scores[:, gt_cls]
|
|
88
88
|
if self.use_fl:
|
|
89
|
-
neg_cost_class = (1 - self.alpha) * (pred_scores
|
|
89
|
+
neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
|
|
90
90
|
pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
|
|
91
91
|
cost_class = pos_cost_class - neg_cost_class
|
|
92
92
|
else:
|
|
@@ -99,9 +99,11 @@ class HungarianMatcher(nn.Module):
|
|
|
99
99
|
cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
|
|
100
100
|
|
|
101
101
|
# Final cost matrix
|
|
102
|
-
C =
|
|
103
|
-
self.cost_gain[
|
|
104
|
-
self.cost_gain[
|
|
102
|
+
C = (
|
|
103
|
+
self.cost_gain["class"] * cost_class
|
|
104
|
+
+ self.cost_gain["bbox"] * cost_bbox
|
|
105
|
+
+ self.cost_gain["giou"] * cost_giou
|
|
106
|
+
)
|
|
105
107
|
# Compute the mask cost and dice cost
|
|
106
108
|
if self.with_mask:
|
|
107
109
|
C += self._cost_mask(bs, gt_groups, masks, gt_mask)
|
|
@@ -111,10 +113,11 @@ class HungarianMatcher(nn.Module):
|
|
|
111
113
|
|
|
112
114
|
C = C.view(bs, nq, -1).cpu()
|
|
113
115
|
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
|
|
114
|
-
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
116
|
+
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt)
|
|
117
|
+
return [
|
|
118
|
+
(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
|
|
119
|
+
for k, (i, j) in enumerate(indices)
|
|
120
|
+
]
|
|
118
121
|
|
|
119
122
|
# This function is for future RT-DETR Segment models
|
|
120
123
|
# def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
|
|
@@ -147,14 +150,9 @@ class HungarianMatcher(nn.Module):
|
|
|
147
150
|
# return C
|
|
148
151
|
|
|
149
152
|
|
|
150
|
-
def get_cdn_group(
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
class_embed,
|
|
154
|
-
num_dn=100,
|
|
155
|
-
cls_noise_ratio=0.5,
|
|
156
|
-
box_noise_scale=1.0,
|
|
157
|
-
training=False):
|
|
153
|
+
def get_cdn_group(
|
|
154
|
+
batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
|
|
155
|
+
):
|
|
158
156
|
"""
|
|
159
157
|
Get contrastive denoising training group. This function creates a contrastive denoising training group with positive
|
|
160
158
|
and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
|
|
@@ -180,7 +178,7 @@ def get_cdn_group(batch,
|
|
|
180
178
|
|
|
181
179
|
if (not training) or num_dn <= 0:
|
|
182
180
|
return None, None, None, None
|
|
183
|
-
gt_groups = batch[
|
|
181
|
+
gt_groups = batch["gt_groups"]
|
|
184
182
|
total_num = sum(gt_groups)
|
|
185
183
|
max_nums = max(gt_groups)
|
|
186
184
|
if max_nums == 0:
|
|
@@ -190,9 +188,9 @@ def get_cdn_group(batch,
|
|
|
190
188
|
num_group = 1 if num_group == 0 else num_group
|
|
191
189
|
# Pad gt to max_num of a batch
|
|
192
190
|
bs = len(gt_groups)
|
|
193
|
-
gt_cls = batch[
|
|
194
|
-
gt_bbox = batch[
|
|
195
|
-
b_idx = batch[
|
|
191
|
+
gt_cls = batch["cls"] # (bs*num, )
|
|
192
|
+
gt_bbox = batch["bboxes"] # bs*num, 4
|
|
193
|
+
b_idx = batch["batch_idx"]
|
|
196
194
|
|
|
197
195
|
# Each group has positive and negative queries.
|
|
198
196
|
dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
|
|
@@ -245,16 +243,21 @@ def get_cdn_group(batch,
|
|
|
245
243
|
# Reconstruct cannot see each other
|
|
246
244
|
for i in range(num_group):
|
|
247
245
|
if i == 0:
|
|
248
|
-
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
|
|
246
|
+
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
|
|
249
247
|
if i == num_group - 1:
|
|
250
|
-
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
|
|
248
|
+
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
|
|
251
249
|
else:
|
|
252
|
-
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
|
|
253
|
-
attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
|
|
250
|
+
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
|
|
251
|
+
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
|
|
254
252
|
dn_meta = {
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
253
|
+
"dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
|
|
254
|
+
"dn_num_group": num_group,
|
|
255
|
+
"dn_num_split": [num_dn, num_queries],
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
return (
|
|
259
|
+
padding_cls.to(class_embed.device),
|
|
260
|
+
padding_bbox.to(class_embed.device),
|
|
261
|
+
attn_mask.to(class_embed.device),
|
|
262
|
+
dn_meta,
|
|
263
|
+
)
|
|
@@ -4,4 +4,4 @@ from ultralytics.models.yolo.classify.predict import ClassificationPredictor
|
|
|
4
4
|
from ultralytics.models.yolo.classify.train import ClassificationTrainer
|
|
5
5
|
from ultralytics.models.yolo.classify.val import ClassificationValidator
|
|
6
6
|
|
|
7
|
-
__all__ =
|
|
7
|
+
__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
|
|
@@ -30,19 +30,21 @@ class ClassificationPredictor(BasePredictor):
|
|
|
30
30
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
31
31
|
"""Initializes ClassificationPredictor setting the task to 'classify'."""
|
|
32
32
|
super().__init__(cfg, overrides, _callbacks)
|
|
33
|
-
self.args.task =
|
|
34
|
-
self._legacy_transform_name =
|
|
33
|
+
self.args.task = "classify"
|
|
34
|
+
self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
|
|
35
35
|
|
|
36
36
|
def preprocess(self, img):
|
|
37
37
|
"""Converts input image to model-compatible data type."""
|
|
38
38
|
if not isinstance(img, torch.Tensor):
|
|
39
|
-
is_legacy_transform = any(
|
|
40
|
-
|
|
39
|
+
is_legacy_transform = any(
|
|
40
|
+
self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
|
|
41
|
+
)
|
|
41
42
|
if is_legacy_transform: # to handle legacy transforms
|
|
42
43
|
img = torch.stack([self.transforms(im) for im in img], dim=0)
|
|
43
44
|
else:
|
|
44
|
-
img = torch.stack(
|
|
45
|
-
|
|
45
|
+
img = torch.stack(
|
|
46
|
+
[self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
|
|
47
|
+
)
|
|
46
48
|
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
|
|
47
49
|
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
|
48
50
|
|
|
@@ -33,23 +33,23 @@ class ClassificationTrainer(BaseTrainer):
|
|
|
33
33
|
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
|
|
34
34
|
if overrides is None:
|
|
35
35
|
overrides = {}
|
|
36
|
-
overrides[
|
|
37
|
-
if overrides.get(
|
|
38
|
-
overrides[
|
|
36
|
+
overrides["task"] = "classify"
|
|
37
|
+
if overrides.get("imgsz") is None:
|
|
38
|
+
overrides["imgsz"] = 224
|
|
39
39
|
super().__init__(cfg, overrides, _callbacks)
|
|
40
40
|
|
|
41
41
|
def set_model_attributes(self):
|
|
42
42
|
"""Set the YOLO model's class names from the loaded dataset."""
|
|
43
|
-
self.model.names = self.data[
|
|
43
|
+
self.model.names = self.data["names"]
|
|
44
44
|
|
|
45
45
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
|
46
46
|
"""Returns a modified PyTorch model configured for training YOLO."""
|
|
47
|
-
model = ClassificationModel(cfg, nc=self.data[
|
|
47
|
+
model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
|
|
48
48
|
if weights:
|
|
49
49
|
model.load(weights)
|
|
50
50
|
|
|
51
51
|
for m in model.modules():
|
|
52
|
-
if not self.args.pretrained and hasattr(m,
|
|
52
|
+
if not self.args.pretrained and hasattr(m, "reset_parameters"):
|
|
53
53
|
m.reset_parameters()
|
|
54
54
|
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
|
|
55
55
|
m.p = self.args.dropout # set dropout
|
|
@@ -64,32 +64,32 @@ class ClassificationTrainer(BaseTrainer):
|
|
|
64
64
|
|
|
65
65
|
model, ckpt = str(self.model), None
|
|
66
66
|
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
|
67
|
-
if model.endswith(
|
|
68
|
-
self.model, ckpt = attempt_load_one_weight(model, device=
|
|
67
|
+
if model.endswith(".pt"):
|
|
68
|
+
self.model, ckpt = attempt_load_one_weight(model, device="cpu")
|
|
69
69
|
for p in self.model.parameters():
|
|
70
70
|
p.requires_grad = True # for training
|
|
71
|
-
elif model.split(
|
|
71
|
+
elif model.split(".")[-1] in ("yaml", "yml"):
|
|
72
72
|
self.model = self.get_model(cfg=model)
|
|
73
73
|
elif model in torchvision.models.__dict__:
|
|
74
|
-
self.model = torchvision.models.__dict__[model](weights=
|
|
74
|
+
self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
|
|
75
75
|
else:
|
|
76
|
-
FileNotFoundError(f
|
|
77
|
-
ClassificationModel.reshape_outputs(self.model, self.data[
|
|
76
|
+
FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
|
|
77
|
+
ClassificationModel.reshape_outputs(self.model, self.data["nc"])
|
|
78
78
|
|
|
79
79
|
return ckpt
|
|
80
80
|
|
|
81
|
-
def build_dataset(self, img_path, mode=
|
|
81
|
+
def build_dataset(self, img_path, mode="train", batch=None):
|
|
82
82
|
"""Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
|
|
83
|
-
return ClassificationDataset(root=img_path, args=self.args, augment=mode ==
|
|
83
|
+
return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
|
|
84
84
|
|
|
85
|
-
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode=
|
|
85
|
+
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
|
86
86
|
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
|
|
87
87
|
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
|
88
88
|
dataset = self.build_dataset(dataset_path, mode)
|
|
89
89
|
|
|
90
90
|
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
|
|
91
91
|
# Attach inference transforms
|
|
92
|
-
if mode !=
|
|
92
|
+
if mode != "train":
|
|
93
93
|
if is_parallel(self.model):
|
|
94
94
|
self.model.module.transforms = loader.dataset.torch_transforms
|
|
95
95
|
else:
|
|
@@ -98,27 +98,32 @@ class ClassificationTrainer(BaseTrainer):
|
|
|
98
98
|
|
|
99
99
|
def preprocess_batch(self, batch):
|
|
100
100
|
"""Preprocesses a batch of images and classes."""
|
|
101
|
-
batch[
|
|
102
|
-
batch[
|
|
101
|
+
batch["img"] = batch["img"].to(self.device)
|
|
102
|
+
batch["cls"] = batch["cls"].to(self.device)
|
|
103
103
|
return batch
|
|
104
104
|
|
|
105
105
|
def progress_string(self):
|
|
106
106
|
"""Returns a formatted string showing training progress."""
|
|
107
|
-
return (
|
|
108
|
-
|
|
107
|
+
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
|
|
108
|
+
"Epoch",
|
|
109
|
+
"GPU_mem",
|
|
110
|
+
*self.loss_names,
|
|
111
|
+
"Instances",
|
|
112
|
+
"Size",
|
|
113
|
+
)
|
|
109
114
|
|
|
110
115
|
def get_validator(self):
|
|
111
116
|
"""Returns an instance of ClassificationValidator for validation."""
|
|
112
|
-
self.loss_names = [
|
|
117
|
+
self.loss_names = ["loss"]
|
|
113
118
|
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)
|
|
114
119
|
|
|
115
|
-
def label_loss_items(self, loss_items=None, prefix=
|
|
120
|
+
def label_loss_items(self, loss_items=None, prefix="train"):
|
|
116
121
|
"""
|
|
117
122
|
Returns a loss dict with labelled training loss items tensor.
|
|
118
123
|
|
|
119
124
|
Not needed for classification but necessary for segmentation & detection
|
|
120
125
|
"""
|
|
121
|
-
keys = [f
|
|
126
|
+
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
|
122
127
|
if loss_items is None:
|
|
123
128
|
return keys
|
|
124
129
|
loss_items = [round(float(loss_items), 5)]
|
|
@@ -134,19 +139,20 @@ class ClassificationTrainer(BaseTrainer):
|
|
|
134
139
|
if f.exists():
|
|
135
140
|
strip_optimizer(f) # strip optimizers
|
|
136
141
|
if f is self.best:
|
|
137
|
-
LOGGER.info(f
|
|
142
|
+
LOGGER.info(f"\nValidating {f}...")
|
|
138
143
|
self.validator.args.data = self.args.data
|
|
139
144
|
self.validator.args.plots = self.args.plots
|
|
140
145
|
self.metrics = self.validator(model=f)
|
|
141
|
-
self.metrics.pop(
|
|
142
|
-
self.run_callbacks(
|
|
146
|
+
self.metrics.pop("fitness", None)
|
|
147
|
+
self.run_callbacks("on_fit_epoch_end")
|
|
143
148
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
|
144
149
|
|
|
145
150
|
def plot_training_samples(self, batch, ni):
|
|
146
151
|
"""Plots training samples with their annotations."""
|
|
147
152
|
plot_images(
|
|
148
|
-
images=batch[
|
|
149
|
-
batch_idx=torch.arange(len(batch[
|
|
150
|
-
cls=batch[
|
|
151
|
-
fname=self.save_dir / f
|
|
152
|
-
on_plot=self.on_plot
|
|
153
|
+
images=batch["img"],
|
|
154
|
+
batch_idx=torch.arange(len(batch["img"])),
|
|
155
|
+
cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
|
|
156
|
+
fname=self.save_dir / f"train_batch{ni}.jpg",
|
|
157
|
+
on_plot=self.on_plot,
|
|
158
|
+
)
|