ultralytics 8.3.89__py3-none-any.whl → 8.3.90__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +118 -30
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +5 -5
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +13 -19
- ultralytics/engine/exporter.py +19 -17
- ultralytics/engine/model.py +67 -88
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +21 -18
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +12 -13
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +20 -11
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +22 -11
- ultralytics/models/nas/predict.py +9 -4
- ultralytics/models/nas/val.py +5 -5
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +18 -15
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +42 -6
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +24 -3
- ultralytics/models/yolo/classify/train.py +77 -10
- ultralytics/models/yolo/classify/val.py +40 -15
- ultralytics/models/yolo/detect/predict.py +23 -10
- ultralytics/models/yolo/detect/train.py +85 -15
- ultralytics/models/yolo/detect/val.py +145 -21
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +12 -4
- ultralytics/models/yolo/obb/train.py +7 -0
- ultralytics/models/yolo/obb/val.py +25 -7
- ultralytics/models/yolo/pose/predict.py +22 -6
- ultralytics/models/yolo/pose/train.py +17 -1
- ultralytics/models/yolo/pose/val.py +46 -21
- ultralytics/models/yolo/segment/predict.py +22 -8
- ultralytics/models/yolo/segment/train.py +6 -0
- ultralytics/models/yolo/segment/val.py +100 -14
- ultralytics/models/yolo/world/train.py +38 -8
- ultralytics/models/yolo/world/train_world.py +39 -10
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +3 -0
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +221 -69
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +32 -27
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +35 -22
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +116 -35
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +13 -9
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +112 -45
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +61 -53
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +64 -45
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +181 -33
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +8 -16
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/METADATA +1 -1
- ultralytics-8.3.90.dist-info/RECORD +250 -0
- ultralytics-8.3.89.dist-info/RECORD +0 -250
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/models/utils/loss.py
CHANGED
@@ -12,21 +12,22 @@ from .ops import HungarianMatcher
|
|
12
12
|
|
13
13
|
class DETRLoss(nn.Module):
|
14
14
|
"""
|
15
|
-
DETR (DEtection TRansformer) Loss class
|
16
|
-
|
17
|
-
losses
|
15
|
+
DETR (DEtection TRansformer) Loss class for calculating various loss components.
|
16
|
+
|
17
|
+
This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the
|
18
|
+
DETR object detection model.
|
18
19
|
|
19
20
|
Attributes:
|
20
|
-
nc (int):
|
21
|
-
loss_gain (
|
21
|
+
nc (int): Number of classes.
|
22
|
+
loss_gain (Dict): Coefficients for different loss components.
|
22
23
|
aux_loss (bool): Whether to compute auxiliary losses.
|
23
|
-
use_fl (bool):
|
24
|
-
use_vfl (bool):
|
25
|
-
use_uni_match (bool): Whether to use a fixed layer
|
26
|
-
uni_match_ind (int):
|
24
|
+
use_fl (bool): Whether to use FocalLoss.
|
25
|
+
use_vfl (bool): Whether to use VarifocalLoss.
|
26
|
+
use_uni_match (bool): Whether to use a fixed layer for auxiliary branch label assignment.
|
27
|
+
uni_match_ind (int): Index of fixed layer to use if use_uni_match is True.
|
27
28
|
matcher (HungarianMatcher): Object to compute matching cost and indices.
|
28
|
-
fl (FocalLoss
|
29
|
-
vfl (VarifocalLoss
|
29
|
+
fl (FocalLoss | None): Focal Loss object if use_fl is True, otherwise None.
|
30
|
+
vfl (VarifocalLoss | None): Varifocal Loss object if use_vfl is True, otherwise None.
|
30
31
|
device (torch.device): Device on which tensors are stored.
|
31
32
|
"""
|
32
33
|
|
@@ -36,16 +37,16 @@ class DETRLoss(nn.Module):
|
|
36
37
|
"""
|
37
38
|
Initialize DETR loss function with customizable components and gains.
|
38
39
|
|
39
|
-
Uses default loss_gain if not provided. Initializes HungarianMatcher with
|
40
|
-
|
40
|
+
Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
|
41
|
+
losses and various loss types.
|
41
42
|
|
42
43
|
Args:
|
43
44
|
nc (int): Number of classes.
|
44
|
-
loss_gain (
|
45
|
-
aux_loss (bool):
|
46
|
-
use_fl (bool):
|
47
|
-
use_vfl (bool):
|
48
|
-
use_uni_match (bool):
|
45
|
+
loss_gain (Dict): Coefficients for different loss components.
|
46
|
+
aux_loss (bool): Whether to use auxiliary losses from each decoder layer.
|
47
|
+
use_fl (bool): Whether to use FocalLoss.
|
48
|
+
use_vfl (bool): Whether to use VarifocalLoss.
|
49
|
+
use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.
|
49
50
|
uni_match_ind (int): Index of fixed layer for uni_match.
|
50
51
|
"""
|
51
52
|
super().__init__()
|
@@ -64,7 +65,7 @@ class DETRLoss(nn.Module):
|
|
64
65
|
self.device = None
|
65
66
|
|
66
67
|
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
|
67
|
-
"""
|
68
|
+
"""Compute classification loss based on predictions, target values, and ground truth scores."""
|
68
69
|
# Logits: [b, query, num_classes], gt_class: list[[n, 1]]
|
69
70
|
name_class = f"loss_class{postfix}"
|
70
71
|
bs, nq = pred_scores.shape[:2]
|
@@ -86,7 +87,7 @@ class DETRLoss(nn.Module):
|
|
86
87
|
return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
|
87
88
|
|
88
89
|
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
|
89
|
-
"""
|
90
|
+
"""Compute bounding box and GIoU losses for predicted and ground truth bounding boxes."""
|
90
91
|
# Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
|
91
92
|
name_bbox = f"loss_bbox{postfix}"
|
92
93
|
name_giou = f"loss_giou{postfix}"
|
@@ -146,7 +147,23 @@ class DETRLoss(nn.Module):
|
|
146
147
|
masks=None,
|
147
148
|
gt_mask=None,
|
148
149
|
):
|
149
|
-
"""
|
150
|
+
"""
|
151
|
+
Get auxiliary losses for intermediate decoder layers.
|
152
|
+
|
153
|
+
Args:
|
154
|
+
pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
|
155
|
+
pred_scores (torch.Tensor): Predicted scores from auxiliary layers.
|
156
|
+
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
|
157
|
+
gt_cls (torch.Tensor): Ground truth classes.
|
158
|
+
gt_groups (List[int]): Number of ground truths per image.
|
159
|
+
match_indices (List[tuple], optional): Pre-computed matching indices.
|
160
|
+
postfix (str): String to append to loss names.
|
161
|
+
masks (torch.Tensor, optional): Predicted masks if using segmentation.
|
162
|
+
gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
|
163
|
+
|
164
|
+
Returns:
|
165
|
+
(Dict): Dictionary of auxiliary losses.
|
166
|
+
"""
|
150
167
|
# NOTE: loss class, bbox, giou, mask, dice
|
151
168
|
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
|
152
169
|
if match_indices is None and self.use_uni_match:
|
@@ -192,14 +209,32 @@ class DETRLoss(nn.Module):
|
|
192
209
|
|
193
210
|
@staticmethod
|
194
211
|
def _get_index(match_indices):
|
195
|
-
"""
|
212
|
+
"""
|
213
|
+
Extract batch indices, source indices, and destination indices from match indices.
|
214
|
+
|
215
|
+
Args:
|
216
|
+
match_indices (List[tuple]): List of tuples containing matched indices.
|
217
|
+
|
218
|
+
Returns:
|
219
|
+
(tuple): Tuple containing (batch_idx, src_idx) and dst_idx.
|
220
|
+
"""
|
196
221
|
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
|
197
222
|
src_idx = torch.cat([src for (src, _) in match_indices])
|
198
223
|
dst_idx = torch.cat([dst for (_, dst) in match_indices])
|
199
224
|
return (batch_idx, src_idx), dst_idx
|
200
225
|
|
201
226
|
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
|
202
|
-
"""
|
227
|
+
"""
|
228
|
+
Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
|
229
|
+
|
230
|
+
Args:
|
231
|
+
pred_bboxes (torch.Tensor): Predicted bounding boxes.
|
232
|
+
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
|
233
|
+
match_indices (List[tuple]): List of tuples containing matched indices.
|
234
|
+
|
235
|
+
Returns:
|
236
|
+
(tuple): Tuple containing assigned predictions and ground truths.
|
237
|
+
"""
|
203
238
|
pred_assigned = torch.cat(
|
204
239
|
[
|
205
240
|
t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
@@ -226,7 +261,23 @@ class DETRLoss(nn.Module):
|
|
226
261
|
postfix="",
|
227
262
|
match_indices=None,
|
228
263
|
):
|
229
|
-
"""
|
264
|
+
"""
|
265
|
+
Calculate losses for a single prediction layer.
|
266
|
+
|
267
|
+
Args:
|
268
|
+
pred_bboxes (torch.Tensor): Predicted bounding boxes.
|
269
|
+
pred_scores (torch.Tensor): Predicted class scores.
|
270
|
+
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
|
271
|
+
gt_cls (torch.Tensor): Ground truth classes.
|
272
|
+
gt_groups (List[int]): Number of ground truths per image.
|
273
|
+
masks (torch.Tensor, optional): Predicted masks if using segmentation.
|
274
|
+
gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
|
275
|
+
postfix (str): String to append to loss names.
|
276
|
+
match_indices (List[tuple], optional): Pre-computed matching indices.
|
277
|
+
|
278
|
+
Returns:
|
279
|
+
(Dict): Dictionary of losses.
|
280
|
+
"""
|
230
281
|
if match_indices is None:
|
231
282
|
match_indices = self.matcher(
|
232
283
|
pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
|
@@ -256,7 +307,7 @@ class DETRLoss(nn.Module):
|
|
256
307
|
Args:
|
257
308
|
pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
|
258
309
|
pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
|
259
|
-
batch (
|
310
|
+
batch (Dict): Batch information containing:
|
260
311
|
cls (torch.Tensor): Ground truth classes, shape [num_gts].
|
261
312
|
bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
|
262
313
|
gt_groups (List[int]): Number of ground truths for each image in the batch.
|
@@ -264,9 +315,9 @@ class DETRLoss(nn.Module):
|
|
264
315
|
**kwargs (Any): Additional arguments, may include 'match_indices'.
|
265
316
|
|
266
317
|
Returns:
|
267
|
-
(
|
318
|
+
(Dict): Computed losses, including main and auxiliary (if enabled).
|
268
319
|
|
269
|
-
|
320
|
+
Notes:
|
270
321
|
Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
|
271
322
|
self.aux_loss is True.
|
272
323
|
"""
|
@@ -298,17 +349,17 @@ class RTDETRDetectionLoss(DETRLoss):
|
|
298
349
|
|
299
350
|
def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
|
300
351
|
"""
|
301
|
-
Forward pass to compute
|
352
|
+
Forward pass to compute detection loss with optional denoising loss.
|
302
353
|
|
303
354
|
Args:
|
304
|
-
preds (tuple):
|
305
|
-
batch (
|
306
|
-
dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
|
307
|
-
dn_scores (torch.Tensor, optional): Denoising scores.
|
308
|
-
dn_meta (
|
355
|
+
preds (tuple): Tuple containing predicted bounding boxes and scores.
|
356
|
+
batch (Dict): Batch data containing ground truth information.
|
357
|
+
dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
|
358
|
+
dn_scores (torch.Tensor, optional): Denoising scores.
|
359
|
+
dn_meta (Dict, optional): Metadata for denoising.
|
309
360
|
|
310
361
|
Returns:
|
311
|
-
(
|
362
|
+
(Dict): Dictionary containing total loss and denoising loss if applicable.
|
312
363
|
"""
|
313
364
|
pred_bboxes, pred_scores = preds
|
314
365
|
total_loss = super().forward(pred_bboxes, pred_scores, batch)
|
@@ -333,12 +384,12 @@ class RTDETRDetectionLoss(DETRLoss):
|
|
333
384
|
@staticmethod
|
334
385
|
def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
|
335
386
|
"""
|
336
|
-
Get
|
387
|
+
Get match indices for denoising.
|
337
388
|
|
338
389
|
Args:
|
339
390
|
dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
|
340
391
|
dn_num_group (int): Number of denoising groups.
|
341
|
-
gt_groups (List[int]): List of integers representing
|
392
|
+
gt_groups (List[int]): List of integers representing number of ground truths per image.
|
342
393
|
|
343
394
|
Returns:
|
344
395
|
(List[tuple]): List of tuples containing matched indices for denoising.
|
ultralytics/models/utils/ops.py
CHANGED
@@ -18,7 +18,7 @@ class HungarianMatcher(nn.Module):
|
|
18
18
|
function that considers classification scores, bounding box coordinates, and optionally, mask predictions.
|
19
19
|
|
20
20
|
Attributes:
|
21
|
-
cost_gain (
|
21
|
+
cost_gain (Dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
|
22
22
|
use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
|
23
23
|
with_mask (bool): Indicates whether the model makes mask predictions.
|
24
24
|
num_sample_points (int): The number of sample points used in mask cost calculation.
|
@@ -26,13 +26,12 @@ class HungarianMatcher(nn.Module):
|
|
26
26
|
gamma (float): The gamma factor in Focal Loss calculation.
|
27
27
|
|
28
28
|
Methods:
|
29
|
-
forward
|
30
|
-
|
31
|
-
_cost_mask(bs, num_gts, masks=None, gt_mask=None): Computes the mask cost and dice cost if masks are predicted.
|
29
|
+
forward: Computes the assignment between predictions and ground truths for a batch.
|
30
|
+
_cost_mask: Computes the mask cost and dice cost if masks are predicted.
|
32
31
|
"""
|
33
32
|
|
34
33
|
def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
|
35
|
-
"""
|
34
|
+
"""Initialize a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
|
36
35
|
super().__init__()
|
37
36
|
if cost_gain is None:
|
38
37
|
cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
|
@@ -45,24 +44,21 @@ class HungarianMatcher(nn.Module):
|
|
45
44
|
|
46
45
|
def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
|
47
46
|
"""
|
48
|
-
Forward pass for HungarianMatcher.
|
49
|
-
|
50
|
-
predictions and ground truth based on these costs.
|
47
|
+
Forward pass for HungarianMatcher. Computes costs based on prediction and ground truth and finds the optimal
|
48
|
+
matching between predictions and ground truth based on these costs.
|
51
49
|
|
52
50
|
Args:
|
53
|
-
pred_bboxes (Tensor): Predicted bounding boxes with shape
|
54
|
-
pred_scores (Tensor): Predicted scores with shape
|
55
|
-
gt_cls (torch.Tensor): Ground truth classes with shape
|
56
|
-
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape
|
51
|
+
pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
|
52
|
+
pred_scores (torch.Tensor): Predicted scores with shape (batch_size, num_queries, num_classes).
|
53
|
+
gt_cls (torch.Tensor): Ground truth classes with shape (num_gts, ).
|
54
|
+
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
|
57
55
|
gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
|
58
56
|
each image.
|
59
|
-
masks (Tensor, optional): Predicted masks with shape
|
60
|
-
|
61
|
-
gt_mask (List[Tensor], optional): List of ground truth masks, each with shape [num_masks, Height, Width].
|
62
|
-
Defaults to None.
|
57
|
+
masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
|
58
|
+
gt_mask (List[torch.Tensor], optional): List of ground truth masks, each with shape (num_masks, Height, Width).
|
63
59
|
|
64
60
|
Returns:
|
65
|
-
(List[Tuple[Tensor, Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
|
61
|
+
(List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
|
66
62
|
- index_i is the tensor of indices of the selected predictions (in order)
|
67
63
|
- index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
|
68
64
|
For each batch element, it holds:
|
@@ -74,10 +70,10 @@ class HungarianMatcher(nn.Module):
|
|
74
70
|
return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
|
75
71
|
|
76
72
|
# We flatten to compute the cost matrices in a batch
|
77
|
-
#
|
73
|
+
# (batch_size * num_queries, num_classes)
|
78
74
|
pred_scores = pred_scores.detach().view(-1, nc)
|
79
75
|
pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
|
80
|
-
#
|
76
|
+
# (batch_size * num_queries, 4)
|
81
77
|
pred_bboxes = pred_bboxes.detach().view(-1, 4)
|
82
78
|
|
83
79
|
# Compute the classification cost
|
@@ -151,26 +147,25 @@ def get_cdn_group(
|
|
151
147
|
batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
|
152
148
|
):
|
153
149
|
"""
|
154
|
-
Get contrastive denoising training group
|
155
|
-
and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
|
156
|
-
and returns the modified labels, bounding boxes, attention mask and meta information.
|
150
|
+
Get contrastive denoising training group with positive and negative samples from ground truths.
|
157
151
|
|
158
152
|
Args:
|
159
|
-
batch (
|
160
|
-
(torch.Tensor with shape
|
153
|
+
batch (Dict): A dict that includes 'gt_cls' (torch.Tensor with shape (num_gts, )), 'gt_bboxes'
|
154
|
+
(torch.Tensor with shape (num_gts, 4)), 'gt_groups' (List[int]) which is a list of batch size length
|
161
155
|
indicating the number of gts of each image.
|
162
156
|
num_classes (int): Number of classes.
|
163
157
|
num_queries (int): Number of queries.
|
164
158
|
class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
|
165
|
-
num_dn (int, optional): Number of denoising
|
166
|
-
cls_noise_ratio (float, optional): Noise ratio for class labels.
|
167
|
-
box_noise_scale (float, optional): Noise scale for bounding box coordinates.
|
168
|
-
training (bool, optional): If it's in training mode.
|
159
|
+
num_dn (int, optional): Number of denoising queries.
|
160
|
+
cls_noise_ratio (float, optional): Noise ratio for class labels.
|
161
|
+
box_noise_scale (float, optional): Noise scale for bounding box coordinates.
|
162
|
+
training (bool, optional): If it's in training mode.
|
169
163
|
|
170
164
|
Returns:
|
171
|
-
(
|
172
|
-
|
173
|
-
|
165
|
+
padding_cls (Optional[torch.Tensor]): The modified class embeddings for denoising.
|
166
|
+
padding_bbox (Optional[torch.Tensor]): The modified bounding boxes for denoising.
|
167
|
+
attn_mask (Optional[torch.Tensor]): The attention mask for denoising.
|
168
|
+
dn_meta (Optional[Dict]): Meta information for denoising.
|
174
169
|
"""
|
175
170
|
if (not training) or num_dn <= 0 or batch is None:
|
176
171
|
return None, None, None, None
|
@@ -13,6 +13,17 @@ class ClassificationPredictor(BasePredictor):
|
|
13
13
|
"""
|
14
14
|
A class extending the BasePredictor class for prediction based on a classification model.
|
15
15
|
|
16
|
+
This predictor handles the specific requirements of classification models, including preprocessing images
|
17
|
+
and postprocessing predictions to generate classification results.
|
18
|
+
|
19
|
+
Attributes:
|
20
|
+
args (Dict): Configuration arguments for the predictor.
|
21
|
+
_legacy_transform_name (str): Name of the legacy transform class for backward compatibility.
|
22
|
+
|
23
|
+
Methods:
|
24
|
+
preprocess: Convert input images to model-compatible format.
|
25
|
+
postprocess: Process model predictions into Results objects.
|
26
|
+
|
16
27
|
Notes:
|
17
28
|
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
|
18
29
|
|
@@ -25,13 +36,13 @@ class ClassificationPredictor(BasePredictor):
|
|
25
36
|
"""
|
26
37
|
|
27
38
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
28
|
-
"""
|
39
|
+
"""Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'."""
|
29
40
|
super().__init__(cfg, overrides, _callbacks)
|
30
41
|
self.args.task = "classify"
|
31
42
|
self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
|
32
43
|
|
33
44
|
def preprocess(self, img):
|
34
|
-
"""
|
45
|
+
"""Convert input images to model-compatible tensor format with appropriate normalization."""
|
35
46
|
if not isinstance(img, torch.Tensor):
|
36
47
|
is_legacy_transform = any(
|
37
48
|
self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
|
@@ -46,7 +57,17 @@ class ClassificationPredictor(BasePredictor):
|
|
46
57
|
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
47
58
|
|
48
59
|
def postprocess(self, preds, img, orig_imgs):
|
49
|
-
"""
|
60
|
+
"""
|
61
|
+
Process predictions to return Results objects with classification probabilities.
|
62
|
+
|
63
|
+
Args:
|
64
|
+
preds (torch.Tensor): Raw predictions from the model.
|
65
|
+
img (torch.Tensor): Input images after preprocessing.
|
66
|
+
orig_imgs (List[np.ndarray] | torch.Tensor): Original images before preprocessing.
|
67
|
+
|
68
|
+
Returns:
|
69
|
+
(List[Results]): List of Results objects containing classification results for each image.
|
70
|
+
"""
|
50
71
|
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
51
72
|
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
52
73
|
|
@@ -17,8 +17,28 @@ class ClassificationTrainer(BaseTrainer):
|
|
17
17
|
"""
|
18
18
|
A class extending the BaseTrainer class for training based on a classification model.
|
19
19
|
|
20
|
-
|
21
|
-
|
20
|
+
This trainer handles the training process for image classification tasks, supporting both YOLO classification models
|
21
|
+
and torchvision models.
|
22
|
+
|
23
|
+
Attributes:
|
24
|
+
model (ClassificationModel): The classification model to be trained.
|
25
|
+
data (Dict): Dictionary containing dataset information including class names and number of classes.
|
26
|
+
loss_names (List[str]): Names of the loss functions used during training.
|
27
|
+
validator (ClassificationValidator): Validator instance for model evaluation.
|
28
|
+
|
29
|
+
Methods:
|
30
|
+
set_model_attributes: Set the model's class names from the loaded dataset.
|
31
|
+
get_model: Return a modified PyTorch model configured for training.
|
32
|
+
setup_model: Load, create or download model for classification.
|
33
|
+
build_dataset: Create a ClassificationDataset instance.
|
34
|
+
get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.
|
35
|
+
preprocess_batch: Preprocess a batch of images and classes.
|
36
|
+
progress_string: Return a formatted string showing training progress.
|
37
|
+
get_validator: Return an instance of ClassificationValidator.
|
38
|
+
label_loss_items: Return a loss dict with labelled training loss items.
|
39
|
+
plot_metrics: Plot metrics from a CSV file.
|
40
|
+
final_eval: Evaluate trained model and save validation results.
|
41
|
+
plot_training_samples: Plot training samples with their annotations.
|
22
42
|
|
23
43
|
Examples:
|
24
44
|
>>> from ultralytics.models.yolo.classify import ClassificationTrainer
|
@@ -41,7 +61,17 @@ class ClassificationTrainer(BaseTrainer):
|
|
41
61
|
self.model.names = self.data["names"]
|
42
62
|
|
43
63
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
44
|
-
"""
|
64
|
+
"""
|
65
|
+
Return a modified PyTorch model configured for training YOLO.
|
66
|
+
|
67
|
+
Args:
|
68
|
+
cfg (Any): Model configuration.
|
69
|
+
weights (Any): Pre-trained model weights.
|
70
|
+
verbose (bool): Whether to display model information.
|
71
|
+
|
72
|
+
Returns:
|
73
|
+
(ClassificationModel): Configured PyTorch model for classification.
|
74
|
+
"""
|
45
75
|
model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
|
46
76
|
if weights:
|
47
77
|
model.load(weights)
|
@@ -56,7 +86,12 @@ class ClassificationTrainer(BaseTrainer):
|
|
56
86
|
return model
|
57
87
|
|
58
88
|
def setup_model(self):
|
59
|
-
"""
|
89
|
+
"""
|
90
|
+
Load, create or download model for classification tasks.
|
91
|
+
|
92
|
+
Returns:
|
93
|
+
(Any): Model checkpoint if applicable, otherwise None.
|
94
|
+
"""
|
60
95
|
import torchvision # scope for faster 'import ultralytics'
|
61
96
|
|
62
97
|
if str(self.model) in torchvision.models.__dict__:
|
@@ -70,11 +105,32 @@ class ClassificationTrainer(BaseTrainer):
|
|
70
105
|
return ckpt
|
71
106
|
|
72
107
|
def build_dataset(self, img_path, mode="train", batch=None):
|
73
|
-
"""
|
108
|
+
"""
|
109
|
+
Create a ClassificationDataset instance given an image path and mode.
|
110
|
+
|
111
|
+
Args:
|
112
|
+
img_path (str): Path to the dataset images.
|
113
|
+
mode (str): Dataset mode ('train', 'val', or 'test').
|
114
|
+
batch (Any): Batch information (unused in this implementation).
|
115
|
+
|
116
|
+
Returns:
|
117
|
+
(ClassificationDataset): Dataset for the specified mode.
|
118
|
+
"""
|
74
119
|
return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
|
75
120
|
|
76
121
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
77
|
-
"""
|
122
|
+
"""
|
123
|
+
Return PyTorch DataLoader with transforms to preprocess images.
|
124
|
+
|
125
|
+
Args:
|
126
|
+
dataset_path (str): Path to the dataset.
|
127
|
+
batch_size (int): Number of images per batch.
|
128
|
+
rank (int): Process rank for distributed training.
|
129
|
+
mode (str): 'train', 'val', or 'test' mode.
|
130
|
+
|
131
|
+
Returns:
|
132
|
+
(torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
|
133
|
+
"""
|
78
134
|
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
79
135
|
dataset = self.build_dataset(dataset_path, mode)
|
80
136
|
|
@@ -112,9 +168,14 @@ class ClassificationTrainer(BaseTrainer):
|
|
112
168
|
|
113
169
|
def label_loss_items(self, loss_items=None, prefix="train"):
|
114
170
|
"""
|
115
|
-
|
171
|
+
Return a loss dict with labelled training loss items tensor.
|
116
172
|
|
117
|
-
|
173
|
+
Args:
|
174
|
+
loss_items (torch.Tensor, optional): Loss tensor items.
|
175
|
+
prefix (str): Prefix to prepend to loss names.
|
176
|
+
|
177
|
+
Returns:
|
178
|
+
(Dict[str, float] | List[str]): Dictionary of loss items or list of loss keys if loss_items is None.
|
118
179
|
"""
|
119
180
|
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
120
181
|
if loss_items is None:
|
@@ -123,7 +184,7 @@ class ClassificationTrainer(BaseTrainer):
|
|
123
184
|
return dict(zip(keys, loss_items))
|
124
185
|
|
125
186
|
def plot_metrics(self):
|
126
|
-
"""
|
187
|
+
"""Plot metrics from a CSV file."""
|
127
188
|
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
|
128
189
|
|
129
190
|
def final_eval(self):
|
@@ -140,7 +201,13 @@ class ClassificationTrainer(BaseTrainer):
|
|
140
201
|
self.run_callbacks("on_fit_epoch_end")
|
141
202
|
|
142
203
|
def plot_training_samples(self, batch, ni):
|
143
|
-
"""
|
204
|
+
"""
|
205
|
+
Plot training samples with their annotations.
|
206
|
+
|
207
|
+
Args:
|
208
|
+
batch (Dict[str, torch.Tensor]): Batch containing images and class labels.
|
209
|
+
ni (int): Number of iterations.
|
210
|
+
"""
|
144
211
|
plot_images(
|
145
212
|
images=batch["img"],
|
146
213
|
batch_idx=torch.arange(len(batch["img"])),
|