dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/utils/loss.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
3
7
|
import torch
|
|
4
8
|
import torch.nn as nn
|
|
5
9
|
import torch.nn.functional as F
|
|
@@ -11,15 +15,14 @@ from .ops import HungarianMatcher
|
|
|
11
15
|
|
|
12
16
|
|
|
13
17
|
class DETRLoss(nn.Module):
|
|
14
|
-
"""
|
|
15
|
-
DETR (DEtection TRansformer) Loss class for calculating various loss components.
|
|
18
|
+
"""DETR (DEtection TRansformer) Loss class for calculating various loss components.
|
|
16
19
|
|
|
17
|
-
This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the
|
|
18
|
-
|
|
20
|
+
This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the DETR
|
|
21
|
+
object detection model.
|
|
19
22
|
|
|
20
23
|
Attributes:
|
|
21
24
|
nc (int): Number of classes.
|
|
22
|
-
loss_gain (dict): Coefficients for different loss components.
|
|
25
|
+
loss_gain (dict[str, float]): Coefficients for different loss components.
|
|
23
26
|
aux_loss (bool): Whether to compute auxiliary losses.
|
|
24
27
|
use_fl (bool): Whether to use FocalLoss.
|
|
25
28
|
use_vfl (bool): Whether to use VarifocalLoss.
|
|
@@ -33,32 +36,31 @@ class DETRLoss(nn.Module):
|
|
|
33
36
|
|
|
34
37
|
def __init__(
|
|
35
38
|
self,
|
|
36
|
-
nc=80,
|
|
37
|
-
loss_gain=None,
|
|
38
|
-
aux_loss=True,
|
|
39
|
-
use_fl=True,
|
|
40
|
-
use_vfl=False,
|
|
41
|
-
use_uni_match=False,
|
|
42
|
-
uni_match_ind=0,
|
|
43
|
-
gamma=1.5,
|
|
44
|
-
alpha=0.25,
|
|
39
|
+
nc: int = 80,
|
|
40
|
+
loss_gain: dict[str, float] | None = None,
|
|
41
|
+
aux_loss: bool = True,
|
|
42
|
+
use_fl: bool = True,
|
|
43
|
+
use_vfl: bool = False,
|
|
44
|
+
use_uni_match: bool = False,
|
|
45
|
+
uni_match_ind: int = 0,
|
|
46
|
+
gamma: float = 1.5,
|
|
47
|
+
alpha: float = 0.25,
|
|
45
48
|
):
|
|
46
|
-
"""
|
|
47
|
-
Initialize DETR loss function with customizable components and gains.
|
|
49
|
+
"""Initialize DETR loss function with customizable components and gains.
|
|
48
50
|
|
|
49
51
|
Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
|
|
50
52
|
losses and various loss types.
|
|
51
53
|
|
|
52
54
|
Args:
|
|
53
55
|
nc (int): Number of classes.
|
|
54
|
-
loss_gain (dict): Coefficients for different loss components.
|
|
56
|
+
loss_gain (dict[str, float], optional): Coefficients for different loss components.
|
|
55
57
|
aux_loss (bool): Whether to use auxiliary losses from each decoder layer.
|
|
56
58
|
use_fl (bool): Whether to use FocalLoss.
|
|
57
59
|
use_vfl (bool): Whether to use VarifocalLoss.
|
|
58
60
|
use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.
|
|
59
61
|
uni_match_ind (int): Index of fixed layer for uni_match.
|
|
60
62
|
gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
|
|
61
|
-
alpha (float
|
|
63
|
+
alpha (float): The balancing factor used to address class imbalance.
|
|
62
64
|
"""
|
|
63
65
|
super().__init__()
|
|
64
66
|
|
|
@@ -75,19 +77,20 @@ class DETRLoss(nn.Module):
|
|
|
75
77
|
self.uni_match_ind = uni_match_ind
|
|
76
78
|
self.device = None
|
|
77
79
|
|
|
78
|
-
def _get_loss_class(
|
|
79
|
-
""
|
|
80
|
-
|
|
80
|
+
def _get_loss_class(
|
|
81
|
+
self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
|
|
82
|
+
) -> dict[str, torch.Tensor]:
|
|
83
|
+
"""Compute classification loss based on predictions, target values, and ground truth scores.
|
|
81
84
|
|
|
82
85
|
Args:
|
|
83
|
-
pred_scores (torch.Tensor): Predicted class scores with shape (
|
|
84
|
-
targets (torch.Tensor): Target class indices with shape (
|
|
85
|
-
gt_scores (torch.Tensor): Ground truth confidence scores with shape (
|
|
86
|
+
pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
|
|
87
|
+
targets (torch.Tensor): Target class indices with shape (B, N).
|
|
88
|
+
gt_scores (torch.Tensor): Ground truth confidence scores with shape (B, N).
|
|
86
89
|
num_gts (int): Number of ground truth objects.
|
|
87
90
|
postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.
|
|
88
91
|
|
|
89
92
|
Returns:
|
|
90
|
-
|
|
93
|
+
(dict[str, torch.Tensor]): Dictionary containing classification loss value.
|
|
91
94
|
|
|
92
95
|
Notes:
|
|
93
96
|
The function supports different classification loss types:
|
|
@@ -115,22 +118,20 @@ class DETRLoss(nn.Module):
|
|
|
115
118
|
|
|
116
119
|
return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
|
|
117
120
|
|
|
118
|
-
def _get_loss_bbox(
|
|
119
|
-
""
|
|
120
|
-
|
|
121
|
+
def _get_loss_bbox(
|
|
122
|
+
self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
|
|
123
|
+
) -> dict[str, torch.Tensor]:
|
|
124
|
+
"""Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
|
|
121
125
|
|
|
122
126
|
Args:
|
|
123
|
-
pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (
|
|
124
|
-
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4)
|
|
125
|
-
|
|
126
|
-
postfix (str): String to append to the loss names for identification in multi-loss scenarios.
|
|
127
|
+
pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
|
|
128
|
+
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4).
|
|
129
|
+
postfix (str, optional): String to append to the loss names for identification in multi-loss scenarios.
|
|
127
130
|
|
|
128
131
|
Returns:
|
|
129
|
-
|
|
130
|
-
- loss_bbox{postfix}
|
|
131
|
-
|
|
132
|
-
- loss_giou{postfix} (torch.Tensor): GIoU loss between predicted and ground truth boxes,
|
|
133
|
-
scaled by the giou loss gain.
|
|
132
|
+
(dict[str, torch.Tensor]): Dictionary containing:
|
|
133
|
+
- loss_bbox{postfix}: L1 loss between predicted and ground truth boxes, scaled by the bbox loss gain.
|
|
134
|
+
- loss_giou{postfix}: GIoU loss between predicted and ground truth boxes, scaled by the giou loss gain.
|
|
134
135
|
|
|
135
136
|
Notes:
|
|
136
137
|
If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.
|
|
@@ -184,32 +185,31 @@ class DETRLoss(nn.Module):
|
|
|
184
185
|
|
|
185
186
|
def _get_loss_aux(
|
|
186
187
|
self,
|
|
187
|
-
pred_bboxes,
|
|
188
|
-
pred_scores,
|
|
189
|
-
gt_bboxes,
|
|
190
|
-
gt_cls,
|
|
191
|
-
gt_groups,
|
|
192
|
-
match_indices=None,
|
|
193
|
-
postfix="",
|
|
194
|
-
masks=None,
|
|
195
|
-
gt_mask=None,
|
|
196
|
-
):
|
|
197
|
-
"""
|
|
198
|
-
Get auxiliary losses for intermediate decoder layers.
|
|
188
|
+
pred_bboxes: torch.Tensor,
|
|
189
|
+
pred_scores: torch.Tensor,
|
|
190
|
+
gt_bboxes: torch.Tensor,
|
|
191
|
+
gt_cls: torch.Tensor,
|
|
192
|
+
gt_groups: list[int],
|
|
193
|
+
match_indices: list[tuple] | None = None,
|
|
194
|
+
postfix: str = "",
|
|
195
|
+
masks: torch.Tensor | None = None,
|
|
196
|
+
gt_mask: torch.Tensor | None = None,
|
|
197
|
+
) -> dict[str, torch.Tensor]:
|
|
198
|
+
"""Get auxiliary losses for intermediate decoder layers.
|
|
199
199
|
|
|
200
200
|
Args:
|
|
201
201
|
pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
|
|
202
202
|
pred_scores (torch.Tensor): Predicted scores from auxiliary layers.
|
|
203
203
|
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
|
|
204
204
|
gt_cls (torch.Tensor): Ground truth classes.
|
|
205
|
-
gt_groups (
|
|
206
|
-
match_indices (
|
|
207
|
-
postfix (str): String to append to loss names.
|
|
205
|
+
gt_groups (list[int]): Number of ground truths per image.
|
|
206
|
+
match_indices (list[tuple], optional): Pre-computed matching indices.
|
|
207
|
+
postfix (str, optional): String to append to loss names.
|
|
208
208
|
masks (torch.Tensor, optional): Predicted masks if using segmentation.
|
|
209
209
|
gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
|
|
210
210
|
|
|
211
211
|
Returns:
|
|
212
|
-
(dict): Dictionary of auxiliary losses.
|
|
212
|
+
(dict[str, torch.Tensor]): Dictionary of auxiliary losses.
|
|
213
213
|
"""
|
|
214
214
|
# NOTE: loss class, bbox, giou, mask, dice
|
|
215
215
|
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
|
|
@@ -255,32 +255,34 @@ class DETRLoss(nn.Module):
|
|
|
255
255
|
return loss
|
|
256
256
|
|
|
257
257
|
@staticmethod
|
|
258
|
-
def _get_index(match_indices):
|
|
259
|
-
"""
|
|
260
|
-
Extract batch indices, source indices, and destination indices from match indices.
|
|
258
|
+
def _get_index(match_indices: list[tuple]) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
|
259
|
+
"""Extract batch indices, source indices, and destination indices from match indices.
|
|
261
260
|
|
|
262
261
|
Args:
|
|
263
|
-
match_indices (
|
|
262
|
+
match_indices (list[tuple]): List of tuples containing matched indices.
|
|
264
263
|
|
|
265
264
|
Returns:
|
|
266
|
-
(tuple): Tuple containing (batch_idx, src_idx)
|
|
265
|
+
batch_idx (tuple[torch.Tensor, torch.Tensor]): Tuple containing (batch_idx, src_idx).
|
|
266
|
+
dst_idx (torch.Tensor): Destination indices.
|
|
267
267
|
"""
|
|
268
268
|
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
|
|
269
269
|
src_idx = torch.cat([src for (src, _) in match_indices])
|
|
270
270
|
dst_idx = torch.cat([dst for (_, dst) in match_indices])
|
|
271
271
|
return (batch_idx, src_idx), dst_idx
|
|
272
272
|
|
|
273
|
-
def _get_assigned_bboxes(
|
|
274
|
-
|
|
275
|
-
|
|
273
|
+
def _get_assigned_bboxes(
|
|
274
|
+
self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: list[tuple]
|
|
275
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
276
|
+
"""Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
|
|
276
277
|
|
|
277
278
|
Args:
|
|
278
279
|
pred_bboxes (torch.Tensor): Predicted bounding boxes.
|
|
279
280
|
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
|
|
280
|
-
match_indices (
|
|
281
|
+
match_indices (list[tuple]): List of tuples containing matched indices.
|
|
281
282
|
|
|
282
283
|
Returns:
|
|
283
|
-
(
|
|
284
|
+
pred_assigned (torch.Tensor): Assigned predicted bounding boxes.
|
|
285
|
+
gt_assigned (torch.Tensor): Assigned ground truth bounding boxes.
|
|
284
286
|
"""
|
|
285
287
|
pred_assigned = torch.cat(
|
|
286
288
|
[
|
|
@@ -298,32 +300,31 @@ class DETRLoss(nn.Module):
|
|
|
298
300
|
|
|
299
301
|
def _get_loss(
|
|
300
302
|
self,
|
|
301
|
-
pred_bboxes,
|
|
302
|
-
pred_scores,
|
|
303
|
-
gt_bboxes,
|
|
304
|
-
gt_cls,
|
|
305
|
-
gt_groups,
|
|
306
|
-
masks=None,
|
|
307
|
-
gt_mask=None,
|
|
308
|
-
postfix="",
|
|
309
|
-
match_indices=None,
|
|
310
|
-
):
|
|
311
|
-
"""
|
|
312
|
-
Calculate losses for a single prediction layer.
|
|
303
|
+
pred_bboxes: torch.Tensor,
|
|
304
|
+
pred_scores: torch.Tensor,
|
|
305
|
+
gt_bboxes: torch.Tensor,
|
|
306
|
+
gt_cls: torch.Tensor,
|
|
307
|
+
gt_groups: list[int],
|
|
308
|
+
masks: torch.Tensor | None = None,
|
|
309
|
+
gt_mask: torch.Tensor | None = None,
|
|
310
|
+
postfix: str = "",
|
|
311
|
+
match_indices: list[tuple] | None = None,
|
|
312
|
+
) -> dict[str, torch.Tensor]:
|
|
313
|
+
"""Calculate losses for a single prediction layer.
|
|
313
314
|
|
|
314
315
|
Args:
|
|
315
316
|
pred_bboxes (torch.Tensor): Predicted bounding boxes.
|
|
316
317
|
pred_scores (torch.Tensor): Predicted class scores.
|
|
317
318
|
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
|
|
318
319
|
gt_cls (torch.Tensor): Ground truth classes.
|
|
319
|
-
gt_groups (
|
|
320
|
+
gt_groups (list[int]): Number of ground truths per image.
|
|
320
321
|
masks (torch.Tensor, optional): Predicted masks if using segmentation.
|
|
321
322
|
gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
|
|
322
|
-
postfix (str): String to append to loss names.
|
|
323
|
-
match_indices (
|
|
323
|
+
postfix (str, optional): String to append to loss names.
|
|
324
|
+
match_indices (list[tuple], optional): Pre-computed matching indices.
|
|
324
325
|
|
|
325
326
|
Returns:
|
|
326
|
-
(dict): Dictionary of losses.
|
|
327
|
+
(dict[str, torch.Tensor]): Dictionary of losses.
|
|
327
328
|
"""
|
|
328
329
|
if match_indices is None:
|
|
329
330
|
match_indices = self.matcher(
|
|
@@ -347,22 +348,25 @@ class DETRLoss(nn.Module):
|
|
|
347
348
|
# **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
|
|
348
349
|
}
|
|
349
350
|
|
|
350
|
-
def forward(
|
|
351
|
-
|
|
352
|
-
|
|
351
|
+
def forward(
|
|
352
|
+
self,
|
|
353
|
+
pred_bboxes: torch.Tensor,
|
|
354
|
+
pred_scores: torch.Tensor,
|
|
355
|
+
batch: dict[str, Any],
|
|
356
|
+
postfix: str = "",
|
|
357
|
+
**kwargs: Any,
|
|
358
|
+
) -> dict[str, torch.Tensor]:
|
|
359
|
+
"""Calculate loss for predicted bounding boxes and scores.
|
|
353
360
|
|
|
354
361
|
Args:
|
|
355
|
-
pred_bboxes (torch.Tensor): Predicted bounding boxes, shape
|
|
356
|
-
pred_scores (torch.Tensor): Predicted class scores, shape
|
|
357
|
-
batch (dict): Batch information containing
|
|
358
|
-
|
|
359
|
-
bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
|
|
360
|
-
gt_groups (List[int]): Number of ground truths for each image in the batch.
|
|
361
|
-
postfix (str): Postfix for loss names.
|
|
362
|
+
pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
|
|
363
|
+
pred_scores (torch.Tensor): Predicted class scores, shape (L, B, N, C).
|
|
364
|
+
batch (dict[str, Any]): Batch information containing cls, bboxes, and gt_groups.
|
|
365
|
+
postfix (str, optional): Postfix for loss names.
|
|
362
366
|
**kwargs (Any): Additional arguments, may include 'match_indices'.
|
|
363
367
|
|
|
364
368
|
Returns:
|
|
365
|
-
(dict): Computed losses, including main and auxiliary (if enabled).
|
|
369
|
+
(dict[str, torch.Tensor]): Computed losses, including main and auxiliary (if enabled).
|
|
366
370
|
|
|
367
371
|
Notes:
|
|
368
372
|
Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
|
|
@@ -387,26 +391,31 @@ class DETRLoss(nn.Module):
|
|
|
387
391
|
|
|
388
392
|
|
|
389
393
|
class RTDETRDetectionLoss(DETRLoss):
|
|
390
|
-
"""
|
|
391
|
-
Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
|
|
394
|
+
"""Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
|
|
392
395
|
|
|
393
396
|
This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
|
|
394
397
|
an additional denoising training loss when provided with denoising metadata.
|
|
395
398
|
"""
|
|
396
399
|
|
|
397
|
-
def forward(
|
|
398
|
-
|
|
399
|
-
|
|
400
|
+
def forward(
|
|
401
|
+
self,
|
|
402
|
+
preds: tuple[torch.Tensor, torch.Tensor],
|
|
403
|
+
batch: dict[str, Any],
|
|
404
|
+
dn_bboxes: torch.Tensor | None = None,
|
|
405
|
+
dn_scores: torch.Tensor | None = None,
|
|
406
|
+
dn_meta: dict[str, Any] | None = None,
|
|
407
|
+
) -> dict[str, torch.Tensor]:
|
|
408
|
+
"""Forward pass to compute detection loss with optional denoising loss.
|
|
400
409
|
|
|
401
410
|
Args:
|
|
402
|
-
preds (tuple): Tuple containing predicted bounding boxes and scores.
|
|
403
|
-
batch (dict): Batch data containing ground truth information.
|
|
411
|
+
preds (tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
|
|
412
|
+
batch (dict[str, Any]): Batch data containing ground truth information.
|
|
404
413
|
dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
|
|
405
414
|
dn_scores (torch.Tensor, optional): Denoising scores.
|
|
406
|
-
dn_meta (dict, optional): Metadata for denoising.
|
|
415
|
+
dn_meta (dict[str, Any], optional): Metadata for denoising.
|
|
407
416
|
|
|
408
417
|
Returns:
|
|
409
|
-
(dict): Dictionary containing total loss and denoising loss if applicable.
|
|
418
|
+
(dict[str, torch.Tensor]): Dictionary containing total loss and denoising loss if applicable.
|
|
410
419
|
"""
|
|
411
420
|
pred_bboxes, pred_scores = preds
|
|
412
421
|
total_loss = super().forward(pred_bboxes, pred_scores, batch)
|
|
@@ -429,17 +438,18 @@ class RTDETRDetectionLoss(DETRLoss):
|
|
|
429
438
|
return total_loss
|
|
430
439
|
|
|
431
440
|
@staticmethod
|
|
432
|
-
def get_dn_match_indices(
|
|
433
|
-
|
|
434
|
-
|
|
441
|
+
def get_dn_match_indices(
|
|
442
|
+
dn_pos_idx: list[torch.Tensor], dn_num_group: int, gt_groups: list[int]
|
|
443
|
+
) -> list[tuple[torch.Tensor, torch.Tensor]]:
|
|
444
|
+
"""Get match indices for denoising.
|
|
435
445
|
|
|
436
446
|
Args:
|
|
437
|
-
dn_pos_idx (
|
|
447
|
+
dn_pos_idx (list[torch.Tensor]): List of tensors containing positive indices for denoising.
|
|
438
448
|
dn_num_group (int): Number of denoising groups.
|
|
439
|
-
gt_groups (
|
|
449
|
+
gt_groups (list[int]): List of integers representing number of ground truths per image.
|
|
440
450
|
|
|
441
451
|
Returns:
|
|
442
|
-
(
|
|
452
|
+
(list[tuple[torch.Tensor, torch.Tensor]]): List of tuples containing matched indices for denoising.
|
|
443
453
|
"""
|
|
444
454
|
dn_match_indices = []
|
|
445
455
|
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
|
@@ -447,8 +457,9 @@ class RTDETRDetectionLoss(DETRLoss):
|
|
|
447
457
|
if num_gt > 0:
|
|
448
458
|
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
|
|
449
459
|
gt_idx = gt_idx.repeat(dn_num_group)
|
|
450
|
-
assert len(dn_pos_idx[i]) == len(gt_idx),
|
|
451
|
-
|
|
460
|
+
assert len(dn_pos_idx[i]) == len(gt_idx), (
|
|
461
|
+
f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
|
|
462
|
+
)
|
|
452
463
|
dn_match_indices.append((dn_pos_idx[i], gt_idx))
|
|
453
464
|
else:
|
|
454
465
|
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
|