dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,813 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import torch.nn as nn
|
5
|
+
import torch.nn.functional as F
|
6
|
+
|
7
|
+
from ultralytics.utils.metrics import OKS_SIGMA
|
8
|
+
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
|
9
|
+
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
|
10
|
+
from ultralytics.utils.torch_utils import autocast
|
11
|
+
|
12
|
+
from .metrics import bbox_iou, probiou
|
13
|
+
from .tal import bbox2dist
|
14
|
+
|
15
|
+
|
16
|
+
class VarifocalLoss(nn.Module):
    """
    Varifocal loss (Zhang et al., https://arxiv.org/abs/2008.13367).

    Each element's BCE-with-logits term is weighted by ``gt_score`` where ``label`` is set, and by
    ``alpha * sigmoid(pred_score) ** gamma`` where it is not.

    Args:
        gamma (float): Exponent applied to the predicted probability in the negative-sample weight.
        alpha (float): Scale factor applied in the negative-sample weight.
    """

    def __init__(self, gamma=2.0, alpha=0.75):
        """Store the ``gamma`` exponent and ``alpha`` scale used to weight negative samples."""
        super().__init__()
        self.gamma, self.alpha = gamma, alpha

    def forward(self, pred_score, gt_score, label):
        """
        Compute the weighted BCE-with-logits loss.

        Args:
            pred_score (torch.Tensor): Raw logits; a sigmoid is applied internally for the weight.
            gt_score (torch.Tensor): Soft targets, used both as the BCE target and as the positive weight.
            label (torch.Tensor): Positive mask; ``label`` selects the positive weight, ``1 - label`` the negative one.

        Returns:
            (torch.Tensor): Weighted loss averaged over dim 1, then summed over everything that remains.
        """
        neg_weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label)
        pos_weight = gt_score * label
        weight = neg_weight + pos_weight
        # The BCE itself is evaluated on float32 copies inside autocast(enabled=False).
        with autocast(enabled=False):
            bce = F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none")
            return (bce * weight).mean(1).sum()
|
43
|
+
|
44
|
+
|
45
|
+
class FocalLoss(nn.Module):
    """
    Focal loss built on top of elementwise BCE-with-logits.

    Follows the TF Addons formulation
    (https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py).

    Args:
        gamma (float): Focusing exponent applied to ``(1 - p_t)``.
        alpha (float | list): Balancing factor. It is stored as a tensor and broadcast against ``label``,
            so a list acts as a per-channel factor along the last dim. It is skipped when no entry is > 0.
    """

    def __init__(self, gamma=1.5, alpha=0.25):
        """Initialize FocalLoss with focusing exponent ``gamma`` and balancing factor ``alpha``."""
        super().__init__()
        self.gamma = gamma
        self.alpha = torch.tensor(alpha)

    def forward(self, pred, label):
        """
        Calculate focal loss with modulating and balancing factors.

        Args:
            pred (torch.Tensor): Raw logits.
            label (torch.Tensor): Targets in [0, 1], same shape as ``pred``.

        Returns:
            (torch.Tensor): Loss averaged over dim 1, then summed over everything that remains.
        """
        loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")

        pred_prob = pred.sigmoid()  # prob from logits
        p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
        loss *= (1.0 - p_t) ** self.gamma  # modulating factor

        if (self.alpha > 0).any():
            # Cast a local copy instead of reassigning self.alpha. forward() then leaves module state alone,
            # and self.alpha keeps its original device/dtype across calls with different inputs.
            alpha = self.alpha.to(device=pred.device, dtype=pred.dtype)
            loss *= label * alpha + (1 - label) * (1 - alpha)
        return loss.mean(1).sum()
|
76
|
+
|
77
|
+
|
78
|
+
class DFLoss(nn.Module):
    """Criterion computing Distribution Focal Loss (DFL), https://ieeexplore.ieee.org/document/9792391."""

    def __init__(self, reg_max=16) -> None:
        """
        Initialize the DFL criterion.

        Args:
            reg_max (int): Number of distribution bins (the class dimension of ``pred_dist``).
        """
        super().__init__()
        self.reg_max = reg_max

    def forward(self, pred_dist, target):
        """
        Return the sum of left and right DFL terms.

        Each continuous target is split between its two neighbouring integer bins. Each bin's cross-entropy is
        weighted linearly by the target's distance to the other bin.

        Args:
            pred_dist (torch.Tensor): Bin logits of shape (target.numel(), reg_max).
            target (torch.Tensor): Continuous bin positions. It is NOT modified.

        Returns:
            (torch.Tensor): Loss with ``target``'s shape but with the last dim reduced to 1 by averaging.
        """
        # Out-of-place clamp, so the caller's tensor is untouched. The 0.01 margin keeps the right bin <= reg_max - 1.
        target = target.clamp(0, self.reg_max - 1 - 0.01)
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (
            F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
            + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
        ).mean(-1, keepdim=True)
|
97
|
+
|
98
|
+
|
99
|
+
class BboxLoss(nn.Module):
    """Criterion computing the CIoU loss and, optionally, the DFL loss for axis-aligned boxes."""

    def __init__(self, reg_max=16):
        """
        Build the criterion.

        Args:
            reg_max (int): Number of DFL bins. The DFL term is enabled only when this is greater than 1.
        """
        super().__init__()
        if reg_max > 1:
            self.dfl_loss = DFLoss(reg_max)
        else:
            self.dfl_loss = None

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """
        Return ``(loss_iou, loss_dfl)`` over the foreground anchors selected by ``fg_mask``.

        Each foreground anchor is weighted by the sum of its target scores over the last dim, and both terms are
        divided by ``target_scores_sum``. Boxes are compared in xyxy form (``xywh=False``). When DFL is disabled,
        ``loss_dfl`` is a zero tensor on ``pred_dist``'s device.
        """
        fg_weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        ciou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - ciou) * fg_weight).sum() / target_scores_sum

        if not self.dfl_loss:
            return loss_iou, torch.tensor(0.0).to(pred_dist.device)

        # DFL: distances from anchors to box sides, capped at n_bins - 1
        n_bins = self.dfl_loss.reg_max
        target_ltrb = bbox2dist(anchor_points, target_bboxes, n_bins - 1)
        dfl_per_anchor = self.dfl_loss(pred_dist[fg_mask].view(-1, n_bins), target_ltrb[fg_mask]) * fg_weight
        return loss_iou, dfl_per_anchor.sum() / target_scores_sum
|
122
|
+
|
123
|
+
|
124
|
+
class RotatedBboxLoss(BboxLoss):
    """Criterion computing the ProbIoU loss and, optionally, the DFL loss for rotated boxes."""

    def __init__(self, reg_max):
        """Build the criterion; the parent enables DFL only when ``reg_max > 1``."""
        super().__init__(reg_max)

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """
        Return ``(loss_iou, loss_dfl)`` for rotated boxes over the anchors selected by ``fg_mask``.

        The IoU term uses ``probiou`` on the full box tensors. For DFL, the first four channels of ``target_bboxes``
        are converted with ``xywh2xyxy`` before the distance targets are built. When DFL is disabled, ``loss_dfl``
        is a zero tensor on ``pred_dist``'s device.
        """
        fg_weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        piou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_iou = ((1.0 - piou) * fg_weight).sum() / target_scores_sum

        if not self.dfl_loss:
            return loss_iou, torch.tensor(0.0).to(pred_dist.device)

        n_bins = self.dfl_loss.reg_max
        target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), n_bins - 1)
        dfl_per_anchor = self.dfl_loss(pred_dist[fg_mask].view(-1, n_bins), target_ltrb[fg_mask]) * fg_weight
        return loss_iou, dfl_per_anchor.sum() / target_scores_sum
|
146
|
+
|
147
|
+
|
148
|
+
class KeypointLoss(nn.Module):
    """Criterion computing an OKS-style keypoint loss."""

    def __init__(self, sigmas) -> None:
        """Store the per-keypoint ``sigmas`` used to scale squared distances."""
        super().__init__()
        self.sigmas = sigmas

    def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
        """
        Return the mean of ``1 - exp(-e)`` over masked keypoints.

        ``e`` is the squared x/y distance (channels 0 and 1 of the last dim) divided by
        ``(2 * sigma) ** 2 * area * 2``, following COCO eval. Each row is rescaled by
        ``num_keypoints / num_visible``, where dim 1 of ``kpt_mask`` indexes keypoints and nonzero entries
        count as visible.
        """
        dx = pred_kpts[..., 0] - gt_kpts[..., 0]
        dy = pred_kpts[..., 1] - gt_kpts[..., 1]
        sq_dist = dx.pow(2) + dy.pow(2)

        n_visible = torch.sum(kpt_mask != 0, dim=1)
        kpt_loss_factor = kpt_mask.shape[1] / (n_visible + 1e-9)

        # Textbook variant would be: sq_dist / (2 * (area * sigmas) ** 2 + 1e-9)
        scale = (2 * self.sigmas).pow(2) * (area + 1e-9) * 2
        e = sq_dist / scale
        per_kpt = (1 - torch.exp(-e)) * kpt_mask
        return (kpt_loss_factor.view(-1, 1) * per_kpt).mean()
|
163
|
+
|
164
|
+
|
165
|
+
class v8DetectionLoss:
    """Criterion computing the box, classification and DFL training losses for YOLOv8 detection."""

    def __init__(self, model, tal_topk=10):  # model must be de-paralleled
        """
        Read loss settings from ``model`` and build the assigner and box criterion.

        Args:
            model: De-paralleled model exposing ``parameters()``, ``args`` (hyperparameters) and ``model[-1]``
                (the Detect head, providing ``stride``, ``nc`` and ``reg_max``).
            tal_topk (int): ``topk`` passed to ``TaskAlignedAssigner``.
        """
        device = next(model.parameters()).device
        hyp = model.args
        head = model.model[-1]  # Detect() module

        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.hyp = hyp
        self.stride = head.stride
        self.nc = head.nc
        self.no = head.nc + head.reg_max * 4  # channels per anchor: class logits + 4 * reg_max box logits
        self.reg_max = head.reg_max
        self.device = device
        self.use_dfl = head.reg_max > 1

        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(head.reg_max).to(device)
        self.proj = torch.arange(head.reg_max, dtype=torch.float, device=device)  # bin indices 0..reg_max-1

    def preprocess(self, targets, batch_size, scale_tensor):
        """
        Group flat target rows into a zero-padded per-image tensor.

        Args:
            targets (torch.Tensor): (num_labels, ne) rows; column 0 is the image index, the rest is copied as-is
                (in ``__call__`` that is class followed by an xywh box).
            batch_size (int): Number of images; rows with other image indices are dropped.
            scale_tensor (torch.Tensor): Multiplier applied to columns 1:5 of the output before xywh -> xyxy.

        Returns:
            (torch.Tensor): (batch_size, max_labels_per_image, ne - 1) on ``self.device``.
        """
        num_labels, num_cols = targets.shape
        if num_labels == 0:
            out = torch.zeros(batch_size, 0, num_cols - 1, device=self.device)
        else:
            img_idx = targets[:, 0]
            _, counts = img_idx.unique(return_counts=True)
            max_per_image = counts.to(dtype=torch.int32).max()
            out = torch.zeros(batch_size, max_per_image, num_cols - 1, device=self.device)
            for j in range(batch_size):
                rows = img_idx == j
                n = rows.sum()
                if n:
                    out[j, :n] = targets[rows, 1:]
        out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
        return out

    def bbox_decode(self, anchor_points, pred_dist):
        """
        Decode box predictions into xyxy boxes relative to ``anchor_points``.

        With DFL, each of the 4 sides' ``reg_max`` logits is softmaxed and reduced to its expected bin index.
        Without DFL, ``pred_dist`` is passed to ``dist2bbox`` unchanged.
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            probs = pred_dist.view(b, a, 4, c // 4).softmax(3)
            pred_dist = probs.matmul(self.proj.type(pred_dist.dtype))
        return dist2bbox(pred_dist, anchor_points, xywh=False)

    def __call__(self, preds, batch):
        """
        Compute the detection loss.

        Args:
            preds: Feature maps, or a tuple whose element ``[1]`` holds them.
            batch (dict): Must provide ``batch_idx``, ``cls`` and ``bboxes`` tensors.

        Returns:
            (tuple): ``(loss * batch_size, loss.detach())`` where ``loss`` holds (box, cls, dfl) after gains.
        """
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats = preds[1] if isinstance(preds, tuple) else preds
        n_img = feats[0].shape[0]
        stacked = torch.cat([xi.view(n_img, self.no, -1) for xi in feats], 2)
        pred_distri, pred_scores = stacked.split((self.reg_max * 4, self.nc), 1)

        # (B, C, anchors) -> (B, anchors, C)
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets: padded rows are all zero, so a positive coordinate sum marks a real box
        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )
        target_scores_sum = max(target_scores.sum(), 1)

        # Classification (BCE) term
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum

        # Box (IoU) and DFL terms, only when some anchor was assigned
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )

        for idx, gain in enumerate((self.hyp.box, self.hyp.cls, self.hyp.dfl)):  # box, cls, dfl gains
            loss[idx] *= gain

        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
|
269
|
+
|
270
|
+
|
271
|
+
class v8SegmentationLoss(v8DetectionLoss):
|
272
|
+
"""Criterion class for computing training losses for YOLOv8 segmentation."""
|
273
|
+
|
274
|
+
def __init__(self, model):  # model must be de-paralleled
    """
    Build the detection criterion and record the mask-overlap setting.

    Args:
        model: De-paralleled model; ``model.args.overlap_mask`` is stored as ``self.overlap``.
    """
    super().__init__(model)
    hyp = model.args
    self.overlap = hyp.overlap_mask
|
278
|
+
|
279
|
+
def __call__(self, preds, batch):
    """
    Compute the combined detection and instance-segmentation loss.

    Args:
        preds: Either ``(feats, pred_masks, proto)`` or a sequence whose element ``[1]`` is that triple.
        batch (dict): Must provide ``batch_idx``, ``cls``, ``bboxes`` and ``masks`` tensors.

    Returns:
        (tuple): ``(loss * batch_size, loss.detach())`` where ``loss`` holds (box, seg, cls, dfl) after gains.

    Raises:
        TypeError: If assembling the targets raises a ``RuntimeError`` (e.g. a detect dataset used for segment).
    """
    loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
    feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
    batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
    n_img = feats[0].shape[0]
    stacked = torch.cat([xi.view(n_img, self.no, -1) for xi in feats], 2)
    pred_distri, pred_scores = stacked.split((self.reg_max * 4, self.nc), 1)

    # Put the anchor axis next to the batch axis: (B, C, anchors) -> (B, anchors, C)
    pred_scores, pred_distri, pred_masks = (
        t.permute(0, 2, 1).contiguous() for t in (pred_scores, pred_distri, pred_masks)
    )

    dtype = pred_scores.dtype
    imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
    anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

    # Targets: a RuntimeError while assembling them is re-raised with a dataset-format hint
    try:
        batch_idx = batch["batch_idx"].view(-1, 1)
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
    except RuntimeError as e:
        raise TypeError(
            "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
            "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
            "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
            "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
            "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
        ) from e

    pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

    _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
        pred_scores.detach().sigmoid(),
        (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
        anchor_points * stride_tensor,
        gt_labels,
        gt_bboxes,
        mask_gt,
    )
    target_scores_sum = max(target_scores.sum(), 1)

    # Classification (BCE) term
    loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum

    if not fg_mask.sum():
        # WARNING: keeps proto/pred_masks in the graph to prevent Multi-GPU DDP 'unused gradient' PyTorch errors,
        # do not remove. Note that inf values in these sums may lead to a nan loss.
        loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()
    else:
        # Box (IoU) and DFL terms. The stride division is out-of-place, so target_bboxes keeps its scale
        # for the mask term below.
        loss[0], loss[3] = self.bbox_loss(
            pred_distri,
            pred_bboxes,
            anchor_points,
            target_bboxes / stride_tensor,
            target_scores,
            target_scores_sum,
            fg_mask,
        )
        # Mask term: bring GT masks to the prototype resolution when it differs
        masks = batch["masks"].to(self.device).float()
        if tuple(masks.shape[-2:]) != (mask_h, mask_w):
            masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
        loss[1] = self.calculate_segmentation_loss(
            fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
        )

    # Gains: the seg term reuses the box gain
    for idx, gain in enumerate((self.hyp.box, self.hyp.box, self.hyp.cls, self.hyp.dfl)):  # box, seg, cls, dfl
        loss[idx] *= gain

    return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl)
|
361
|
+
|
362
|
+
@staticmethod
def single_mask_loss(
    gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
) -> torch.Tensor:
    """
    Compute the instance segmentation loss for a single image.

    Predicted mask logits are built as ``einsum('in,nhw->ihw', pred, proto)``. They are scored against
    ``gt_mask`` with elementwise BCE-with-logits and cropped to each instance's box with ``crop_mask``.
    Each instance's mean is then divided by its ``area`` and summed.

    Args:
        gt_mask (torch.Tensor): Ground-truth masks with the same shape as the predicted masks, i.e. (n, H, W).
        pred (torch.Tensor): Mask coefficients of shape (n, nm), where nm equals proto's first dim.
        proto (torch.Tensor): Prototype masks of shape (nm, H, W).
        xyxy (torch.Tensor): Per-instance boxes of shape (n, 4) passed to ``crop_mask``. They are presumably
            expressed in the H x W mask grid; TODO confirm against the caller.
        area (torch.Tensor): Per-instance normalizer of shape (n,).

    Returns:
        (torch.Tensor): Scalar mask loss for the image.
    """
    logits = torch.einsum("in,nhw->ihw", pred, proto)
    pixel_loss = F.binary_cross_entropy_with_logits(logits, gt_mask, reduction="none")
    per_instance = crop_mask(pixel_loss, xyxy).mean(dim=(1, 2))
    return (per_instance / area).sum()
|
386
|
+
|
387
|
+
def calculate_segmentation_loss(
    self,
    fg_mask: torch.Tensor,
    masks: torch.Tensor,
    target_gt_idx: torch.Tensor,
    target_bboxes: torch.Tensor,
    batch_idx: torch.Tensor,
    proto: torch.Tensor,
    pred_masks: torch.Tensor,
    imgsz: torch.Tensor,
    overlap: bool,
) -> torch.Tensor:
    """
    Calculate the loss for instance segmentation.

    Args:
        fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
        masks (torch.Tensor): Ground truth masks.
            - If `overlap` is True: shape (BS, H, W), where each pixel stores the 1-based index of the instance
              covering it (0 = background).
            - If `overlap` is False: shape (N_labels_in_batch, H, W), one binary mask per label, associated with
              images through `batch_idx`.
        target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
        target_bboxes (torch.Tensor): Ground truth bounding boxes (xyxy, pixels) for each anchor of shape
            (BS, N_anchors, 4).
        batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1). Only read when `overlap` is False.
        proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
        pred_masks (torch.Tensor): Predicted mask coefficients for each anchor of shape (BS, N_anchors, 32).
        imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
        overlap (bool): Whether `masks` uses the overlapping (index-encoded) layout described above.

    Returns:
        (torch.Tensor): The calculated loss for instance segmentation, averaged over positive anchors.

    Notes:
        The batch loss can be computed for improved speed at higher memory usage.
        For example, pred_mask can be computed as follows:
            pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
    """
    _, _, mask_h, mask_w = proto.shape
    loss = 0

    # Normalize to 0-1
    target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]

    # Areas of target bboxes: (x2 - x1) * (y2 - y1), same as xyxy2xywh(...)[..., 2:].prod(2)
    marea = (target_bboxes_normalized[..., 2] - target_bboxes_normalized[..., 0]) * (
        target_bboxes_normalized[..., 3] - target_bboxes_normalized[..., 1]
    )

    # Normalize to mask size
    mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)

    flat_batch_idx = None if overlap else batch_idx.view(-1)

    # Iterate over image indices instead of zipping with `masks`. In the non-overlap layout `masks` has one row
    # per label, not per image, so a zip would stop early whenever the batch holds fewer labels than images.
    # That would silently drop the mask loss of the remaining images.
    for i in range(fg_mask.shape[0]):
        fg_mask_i = fg_mask[i]
        if fg_mask_i.any():
            mask_idx = target_gt_idx[i][fg_mask_i]
            if overlap:
                # Per-pixel instance ids -> one binary mask per positive anchor
                gt_mask = (masks[i] == (mask_idx + 1).view(-1, 1, 1)).float()
            else:
                gt_mask = masks[flat_batch_idx == i][mask_idx]

            loss += self.single_mask_loss(
                gt_mask, pred_masks[i][fg_mask_i], proto[i], mxyxy[i][fg_mask_i], marea[i][fg_mask_i]
            )

        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

    return loss / fg_mask.sum()
|
452
|
+
|
453
|
+
|
454
|
+
class v8PoseLoss(v8DetectionLoss):
    """
    Criterion class for computing training losses for YOLOv8 pose estimation.

    Extends the detection criterion with two extra terms:
        - a keypoint location loss computed by `KeypointLoss`;
        - a keypoint visibility ("kobj") BCE loss.
    The loss vector is ordered (box, pose, kobj, cls, dfl).
    """

    def __init__(self, model):  # model must be de-paralleled
        """
        Initialize v8PoseLoss with model parameters and keypoint-specific loss functions.

        Args:
            model: De-paralleled pose model whose last module exposes `kpt_shape` as (num_keypoints, dims).
        """
        super().__init__(model)
        self.kpt_shape = model.model[-1].kpt_shape
        self.bce_pose = nn.BCEWithLogitsLoss()
        # Compare as a tuple so the 17x3 layout is recognized whether kpt_shape is a list or a tuple; a plain
        # `== [17, 3]` is False for (17, 3) and would silently fall back to uniform sigmas.
        is_pose = tuple(self.kpt_shape) == (17, 3)
        nkpt = self.kpt_shape[0]  # number of keypoints
        # Per-keypoint sigmas: the OKS_SIGMA table for the 17x3 layout, otherwise uniform 1/nkpt
        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
        self.keypoint_loss = KeypointLoss(sigmas=sigmas)

    def __call__(self, preds, batch):
        """
        Calculate the total loss and detach it for pose estimation.

        Args:
            preds (tuple | list): Either (feats, pred_kpts), where feats is a list of per-level feature maps, or a
                tuple whose second element is that pair.
            batch (dict): Training batch with "batch_idx", "cls", "bboxes" and "keypoints" entries.

        Returns:
            (tuple[torch.Tensor, torch.Tensor]): The gain-weighted loss vector (box, pose, kobj, cls, dfl)
                multiplied by batch size, and a detached copy of the gain-weighted vector (not batch-scaled).
        """
        loss = torch.zeros(5, device=self.device)  # box, pose (kpt location), kobj (kpt visibility), cls, dfl
        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # (B, channels, grids) -> (B, grids, channels)
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        batch_size = pred_scores.shape[0]
        batch_idx = batch["batch_idx"].view(-1, 1)
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
        # (b, h*w, *kpt_shape)
        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[4] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
            # Keypoints are scaled by image width (x) and height (y) before the per-anchor gather
            keypoints = batch["keypoints"].to(self.device).float().clone()
            keypoints[..., 0] *= imgsz[1]
            keypoints[..., 1] *= imgsz[0]

            loss[1], loss[2] = self.calculate_keypoints_loss(
                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.pose  # pose gain
        loss[2] *= self.hyp.kobj  # kobj gain
        loss[3] *= self.hyp.cls  # cls gain
        loss[4] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)

    @staticmethod
    def kpts_decode(anchor_points, pred_kpts):
        """
        Decode predicted keypoints relative to the anchor points.

        The result stays in grid (stride) units. The caller compares it against ground-truth keypoints that have
        been divided by the stride.

        Args:
            anchor_points (torch.Tensor): Anchor centers of shape (h*w, 2), in grid units.
            pred_kpts (torch.Tensor): Raw keypoint predictions of shape (b, h*w, nkpt, dims).

        Returns:
            (torch.Tensor): A copy of `pred_kpts`. x/y are scaled by 2 and offset by (anchor - 0.5); any further
                channels (e.g. a visibility logit) are unchanged.
        """
        y = pred_kpts.clone()
        y[..., :2] *= 2.0
        y[..., 0] += anchor_points[:, [0]] - 0.5
        y[..., 1] += anchor_points[:, [1]] - 0.5
        return y

    def calculate_keypoints_loss(
        self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
    ):
        """
        Calculate the keypoints loss for the model.

        This function calculates the keypoints loss and keypoints object loss for a given batch:
            - The keypoints loss is based on the difference between the predicted and ground truth keypoints.
            - The keypoints object loss is a binary classification loss that classifies whether a keypoint is
              present or not.

        Args:
            masks (torch.Tensor): Binary foreground mask indicating positive anchors, shape (BS, N_anchors).
            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
            keypoints (torch.Tensor): Ground truth keypoints, shape (N_objects_in_batch, N_kpts_per_object, kpts_dim).
            batch_idx (torch.Tensor): Image index of each object in `keypoints`, shape (N_objects_in_batch, 1).
            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).

        Returns:
            kpts_loss (torch.Tensor | int): The keypoints loss (0 when no anchor is positive).
            kpts_obj_loss (torch.Tensor | int): The keypoints object loss (0 when no anchor is positive or the
                predictions carry no visibility channel).
        """
        batch_idx = batch_idx.flatten()
        batch_size = len(masks)

        # Largest number of objects (keypoint sets) belonging to a single image
        max_objects = torch.unique(batch_idx, return_counts=True)[1].max()

        # Zero-padded per-image keypoints: (BS, max_objects, N_kpts_per_object, kpts_dim)
        batched_keypoints = torch.zeros(
            (batch_size, max_objects, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
        )

        # TODO: any idea how to vectorize this?
        # Fill batched_keypoints with keypoints based on batch_idx
        for i in range(batch_size):
            keypoints_i = keypoints[batch_idx == i]
            batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

        # For every anchor, gather the keypoints of the ground truth object it was assigned to
        target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)
        selected_keypoints = batched_keypoints.gather(
            1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
        )

        # Divide coordinates by stride so they match the grid units of pred_kpts
        selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)

        kpts_loss = 0
        kpts_obj_loss = 0

        if masks.any():
            gt_kpt = selected_keypoints[masks]
            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
            pred_kpt = pred_kpts[masks]
            # With a third channel, a keypoint counts only when that value is non-zero; 2-D keypoints all count
            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss

            if pred_kpt.shape[-1] == 3:
                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss

        return kpts_loss, kpts_obj_loss
|
607
|
+
|
608
|
+
|
609
|
+
class v8ClassificationLoss:
    """Criterion class for computing training losses for classification."""

    def __call__(self, preds, batch):
        """
        Compute the classification loss between predictions and true labels.

        Args:
            preds (torch.Tensor | list | tuple): Class logits, or a sequence whose second element holds the logits.
            batch (dict): Batch containing the class targets under "cls".

        Returns:
            (tuple[torch.Tensor, torch.Tensor]): Mean cross-entropy loss and its detached copy.
        """
        if isinstance(preds, (list, tuple)):
            preds = preds[1]
        mean_ce = F.cross_entropy(preds, batch["cls"], reduction="mean")
        return mean_ce, mean_ce.detach()
|
618
|
+
|
619
|
+
|
620
|
+
class v8OBBLoss(v8DetectionLoss):
    """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""

    def __init__(self, model):
        """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
        super().__init__(model)
        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)

    def preprocess(self, targets, batch_size, scale_tensor):
        """
        Preprocess targets for oriented bounding box detection.

        Args:
            targets (torch.Tensor): Flat targets of shape (N, 7) laid out as (batch_idx, cls, x, y, w, h, r).
            batch_size (int): Number of images in the batch.
            scale_tensor (torch.Tensor): Multipliers for the x, y, w, h columns; the angle column is not scaled.

        Returns:
            (torch.Tensor): Zero-padded per-image targets of shape (batch_size, max_targets_per_image, 6), laid out
                as (cls, x, y, w, h, r).
        """
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 6, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
            for j in range(batch_size):
                matches = i == j
                if n := matches.sum():
                    bboxes = targets[matches, 2:]  # boolean indexing returns a copy, so mul_ is safe
                    bboxes[..., :4].mul_(scale_tensor)
                    out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
        return out

    def __call__(self, preds, batch):
        """
        Calculate and return the loss for oriented bounding box detection.

        Args:
            preds (tuple | list): Either (feats, pred_angle) or a tuple whose second element is that pair.
            batch (dict): Training batch with "batch_idx", "cls" and "bboxes" (x, y, w, h, r per object).

        Returns:
            (tuple[torch.Tensor, torch.Tensor]): Gain-weighted loss vector (box, cls, dfl) multiplied by batch size,
                and a detached copy of the gain-weighted vector.

        Raises:
            TypeError: If the batch targets cannot be assembled into OBB format (e.g. a non-OBB dataset).
        """
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
        batch_size = pred_angle.shape[0]  # number of images in the batch
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # (b, channels, grids) -> (b, grids, channels)
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_angle = pred_angle.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets: (batch_idx, cls, x, y, w, h, r); xywh are scaled to pixels by preprocess below
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
            # imgsz is (h, w). Width (column 4) scales with imgsz[1] and height (column 5) with imgsz[0], matching
            # the (w, h, w, h) scale_tensor passed to preprocess. Swapping them misfilters non-square inputs.
            rw, rh = targets[:, 4] * imgsz[1].item(), targets[:, 5] * imgsz[0].item()
            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
            ) from e

        # Pboxes: rotated boxes with angle appended, (b, h*w, 5)
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)

        bboxes_for_assigner = pred_bboxes.clone().detach()
        # Only the first four elements need to be scaled (the angle is stride-independent)
        bboxes_for_assigner[..., :4] *= stride_tensor
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            bboxes_for_assigner.type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes[..., :4] /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
        else:
            # Keeps pred_angle in the graph so DDP does not report unused parameters
            loss[0] += (pred_angle * 0).sum()

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)

    def bbox_decode(self, anchor_points, pred_dist, pred_angle):
        """
        Decode predicted object bounding box coordinates from anchor points and distribution.

        Args:
            anchor_points (torch.Tensor): Anchor points, (h*w, 2).
            pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4 * reg_max) when DFL is used,
                otherwise (bs, h*w, 4).
            pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).

        Returns:
            (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            # Expected value of each side's discrete distribution over the reg_max bins
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
|
734
|
+
|
735
|
+
|
736
|
+
class E2EDetectLoss:
    """
    Criterion class for computing training losses for end-to-end detection.

    Two v8DetectionLoss instances are evaluated:
        - a one-to-many branch built with tal_topk=10;
        - a one-to-one branch built with tal_topk=1.
    Their results are summed.
    """

    def __init__(self, model):
        """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
        self.one2many = v8DetectionLoss(model, tal_topk=10)
        self.one2one = v8DetectionLoss(model, tal_topk=1)

    def __call__(self, preds, batch):
        """
        Calculate the sum of the loss for box, cls and dfl multiplied by batch size.

        Args:
            preds (dict | tuple): Dict with "one2many" and "one2one" head outputs, or a tuple whose second element
                is such a dict.
            batch (dict): Batch forwarded unchanged to both criteria.

        Returns:
            (tuple): (loss, loss_items), each the sum of the two branches' corresponding outputs.
        """
        head_outputs = preds[1] if isinstance(preds, tuple) else preds
        branches = (("one2many", self.one2many), ("one2one", self.one2one))
        # Evaluated in order: one-to-many first, then one-to-one
        many, one = [criterion(head_outputs[key], batch) for key, criterion in branches]
        return many[0] + one[0], many[1] + one[1]
|
752
|
+
|
753
|
+
|
754
|
+
class TVPDetectLoss:
    """
    Criterion class for computing training losses for text-visual prompt detection.

    Each feature level carries these channels, in order:
        - `reg_max * 4` box channels;
        - `ori_nc` text-prompt class channels;
        - optionally, extra visual-prompt class channels.
    Only the box channels plus the visual-prompt class channels are passed to the wrapped criterion.
    """

    def __init__(self, model):
        """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
        self.vp_criterion = v8DetectionLoss(model)
        # NOTE: snapshot the original head layout, because _get_vp_features() overwrites nc/no on the criterion
        crit = self.vp_criterion
        self.ori_nc, self.ori_no, self.ori_reg_max = crit.nc, crit.no, crit.reg_max

    def __call__(self, preds, batch):
        """
        Calculate the loss for text-visual prompt detection.

        Returns:
            (tuple[torch.Tensor, torch.Tensor]):
                - If the features have no visual-prompt channels: a zero loss of shape (3,) and its detached copy.
                - Otherwise: index 1 (cls) of the criterion's loss vector, plus the criterion's loss items.
        """
        feats = preds[1] if isinstance(preds, tuple) else preds
        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

        text_only_channels = self.ori_reg_max * 4 + self.ori_nc
        if feats[0].shape[1] == text_only_channels:
            # No visual-prompt channels present: contribute a zero loss that still requires grad
            zero_loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
            return zero_loss, zero_loss.detach()

        vp_loss = self.vp_criterion(self._get_vp_features(feats), batch)
        # Index 1 of the (box, cls, dfl) vector: only the classification term is returned
        return vp_loss[0][1], vp_loss[1]

    def _get_vp_features(self, feats):
        """
        Extract visual-prompt features from the model output.

        Steps:
            1. Reconfigure `vp_criterion` (nc, no, assigner.num_classes) for the number of visual-prompt classes
               found in `feats`.
            2. Return each level with the text-prompt class channels removed.
        """
        box_channels = self.ori_reg_max * 4
        vnc = feats[0].shape[1] - box_channels - self.ori_nc

        crit = self.vp_criterion
        crit.nc = vnc
        crit.no = vnc + crit.reg_max * 4
        crit.assigner.num_classes = vnc

        pieces = [xi.split((box_channels, self.ori_nc, vnc), dim=1) for xi in feats]
        return [torch.cat((box, cls_vp), dim=1) for box, _, cls_vp in pieces]
|
791
|
+
|
792
|
+
|
793
|
+
class TVPSegmentLoss(TVPDetectLoss):
    """Criterion class for computing training losses for text-visual prompt segmentation."""

    def __init__(self, model):
        """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
        super().__init__(model)
        # Replace the parent's detection criterion with a segmentation one; the parent's ori_* snapshot is kept
        self.vp_criterion = v8SegmentationLoss(model)

    def __call__(self, preds, batch):
        """
        Calculate the loss for text-visual prompt segmentation.

        Returns:
            (tuple[torch.Tensor, torch.Tensor]):
                - If the features have no visual-prompt channels: a zero loss of shape (4,) and its detached copy.
                - Otherwise: index 2 (cls) of the segmentation criterion's loss vector, plus its loss items.
        """
        if len(preds) == 3:
            feats, pred_masks, proto = preds
        else:
            feats, pred_masks, proto = preds[1]
        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

        text_only_channels = self.ori_reg_max * 4 + self.ori_nc
        if feats[0].shape[1] == text_only_channels:
            zero_loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
            return zero_loss, zero_loss.detach()

        vp_loss = self.vp_criterion((self._get_vp_features(feats), pred_masks, proto), batch)
        # Index 2 of the (box, seg, cls, dfl) vector: only the classification term is returned
        return vp_loss[0][2], vp_loss[1]
|