dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/nn/modules/head.py
CHANGED
@@ -15,17 +15,16 @@ from ultralytics.utils import NOT_MACOS14
 from ultralytics.utils.tal import dist2bbox, dist2rbox, make_anchors
 from ultralytics.utils.torch_utils import TORCH_1_11, fuse_conv_and_bn, smart_inference_mode

-from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN
+from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Proto26, RealNVP, Residual, SwiGLUFFN
 from .conv import Conv, DWConv
 from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 from .utils import bias_init_with_prob, linear_init

-__all__ = "
+__all__ = "OBB", "Classify", "Detect", "Pose", "RTDETRDecoder", "Segment", "YOLOEDetect", "YOLOESegment", "v10Detect"


 class Detect(nn.Module):
-    """
-    YOLO Detect head for object detection models.
+    """YOLO Detect head for object detection models.

     This class implements the detection head used in YOLO models for predicting bounding boxes and class probabilities.
     It supports both training and inference modes, with optional end-to-end detection capabilities.
@@ -69,7 +68,6 @@ class Detect(nn.Module):
     dynamic = False  # force grid reconstruction
     export = False  # export mode
     format = None  # export format
-    end2end = False  # end2end
     max_det = 300  # max_det
     shape = None
     anchors = torch.empty(0)  # init
@@ -77,18 +75,19 @@ class Detect(nn.Module):
     legacy = False  # backward compatibility for v3/v5/v8/v9 models
     xyxy = False  # xyxy or xywh output

-    def __init__(self, nc: int = 80, ch: tuple = ()):
-        """
-        Initialize the YOLO detection layer with specified number of classes and channels.
+    def __init__(self, nc: int = 80, reg_max=16, end2end=False, ch: tuple = ()):
+        """Initialize the YOLO detection layer with specified number of classes and channels.

         Args:
             nc (int): Number of classes.
+            reg_max (int): Maximum number of DFL channels.
+            end2end (bool): Whether to use end-to-end NMS-free detection.
             ch (tuple): Tuple of channel sizes from backbone feature maps.
         """
         super().__init__()
         self.nc = nc  # number of classes
         self.nl = len(ch)  # number of detection layers
-        self.reg_max =
+        self.reg_max = reg_max  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
         self.no = nc + self.reg_max * 4  # number of outputs per anchor
         self.stride = torch.zeros(self.nl)  # strides computed during build
         c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
@@ -109,93 +108,88 @@ class Detect(nn.Module):
         )
         self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

-        if
+        if end2end:
             self.one2one_cv2 = copy.deepcopy(self.cv2)
             self.one2one_cv3 = copy.deepcopy(self.cv3)

-
-
+    @property
+    def one2many(self):
+        """Returns the one-to-many head components, here for v5/v5/v8/v9/11 backward compatibility."""
+        return dict(box_head=self.cv2, cls_head=self.cv3)
+
+    @property
+    def one2one(self):
+        """Returns the one-to-one head components."""
+        return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3)
+
+    @property
+    def end2end(self):
+        """Checks if the model has one2one for v5/v5/v8/v9/11 backward compatibility."""
+        return hasattr(self, "one2one")
+
+    def forward_head(
+        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
+    ) -> dict[str, torch.Tensor]:
+        """Concatenates and returns predicted bounding boxes and class probabilities."""
+        if box_head is None or cls_head is None:  # for fused inference
+            return dict()
+        bs = x[0].shape[0]  # batch size
+        boxes = torch.cat([box_head[i](x[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
+        scores = torch.cat([cls_head[i](x[i]).view(bs, self.nc, -1) for i in range(self.nl)], dim=-1)
+        return dict(boxes=boxes, scores=scores, feats=x)
+
+    def forward(
+        self, x: list[torch.Tensor]
+    ) -> dict[str, torch.Tensor] | torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        """Concatenates and returns predicted bounding boxes and class probabilities."""
+        preds = self.forward_head(x, **self.one2many)
         if self.end2end:
-
-
-
-
-
-
-
-
-
-    def forward_end2end(self, x: list[torch.Tensor]) -> dict | tuple:
-        """
-        Perform forward pass of the v10Detect module.
-
-        Args:
-            x (list[torch.Tensor]): Input feature maps from different levels.
-
-        Returns:
-            outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs.
-                Inference mode returns processed detections or tuple with detections and raw outputs.
-        """
-        x_detach = [xi.detach() for xi in x]
-        one2one = [
-            torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
-        ]
-        for i in range(self.nl):
-            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
-        if self.training:  # Training path
-            return {"one2many": x, "one2one": one2one}
+            x_detach = [xi.detach() for xi in x]
+            one2one = self.forward_head(x_detach, **self.one2one)
+            preds = {"one2many": preds, "one2one": one2one}
+        if self.training:
+            return preds
+        y = self._inference(preds["one2one"] if self.end2end else preds)
+        if self.end2end:
+            y = self.postprocess(y.permute(0, 2, 1))
+        return y if self.export else (y, preds)

-
-
-        return y if self.export else (y, {"one2many": x, "one2one": one2one})
-
-    def _inference(self, x: list[torch.Tensor]) -> torch.Tensor:
-        """
-        Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
+    def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+        """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.

         Args:
-            x (
+            x (dict[str, torch.Tensor]): List of feature maps from different detection layers.

         Returns:
             (torch.Tensor): Concatenated tensor of decoded bounding boxes and class probabilities.
         """
         # Inference path
-
-
+        dbox = self._get_decode_boxes(x)
+        return torch.cat((dbox, x["scores"].sigmoid()), 1)
+
+    def _get_decode_boxes(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+        """Get decoded boxes based on anchors and strides."""
+        shape = x["feats"][0].shape  # BCHW
         if self.dynamic or self.shape != shape:
-            self.anchors, self.strides = (
+            self.anchors, self.strides = (a.transpose(0, 1) for a in make_anchors(x["feats"], self.stride, 0.5))
             self.shape = shape

-
-
-            cls = x_cat[:, self.reg_max * 4 :]
-        else:
-            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
-
-        if self.export and self.format in {"tflite", "edgetpu"}:
-            # Precompute normalization factor to increase numerical stability
-            # See https://github.com/ultralytics/ultralytics/issues/7371
-            grid_h = shape[2]
-            grid_w = shape[3]
-            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
-            norm = self.strides / (self.stride[0] * grid_size)
-            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
-        else:
-            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
-        return torch.cat((dbox, cls.sigmoid()), 1)
+        dbox = self.decode_bboxes(self.dfl(x["boxes"]), self.anchors.unsqueeze(0)) * self.strides
+        return dbox

     def bias_init(self):
         """Initialize Detect() biases, WARNING: requires stride availability."""
-
-
-
-
-
-            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
+        for i, (a, b) in enumerate(zip(self.one2many["box_head"], self.one2many["cls_head"])):  # from
+            a[-1].bias.data[:] = 2.0  # box
+            b[-1].bias.data[: self.nc] = math.log(
+                5 / self.nc / (640 / self.stride[i]) ** 2
+            )  # cls (.01 objects, 80 classes, 640 img)
         if self.end2end:
-            for a, b
-                a[-1].bias.data[:] =
-                b[-1].bias.data[:
+            for i, (a, b) in enumerate(zip(self.one2one["box_head"], self.one2one["cls_head"])):  # from
+                a[-1].bias.data[:] = 2.0  # box
+                b[-1].bias.data[: self.nc] = math.log(
+                    5 / self.nc / (640 / self.stride[i]) ** 2
+                )  # cls (.01 objects, 80 classes, 640 img)

     def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:
         """Decode bounding boxes from predictions."""
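Usage sketch (editorial note, not part of the diff): with the change above, Detect takes reg_max and end2end directly in __init__ and its training-mode forward returns a dict of raw outputs rather than a list of per-level tensors. The snippet below is a minimal illustration only; it assumes the module path ultralytics.nn.modules.head from this wheel and that the parts of the constructor not shown in this hunk (cv2/cv3 construction) are unchanged.

import torch

from ultralytics.nn.modules.head import Detect  # path as diffed above; assumed importable as-is

head = Detect(nc=80, reg_max=16, end2end=True, ch=(256, 512, 1024))
head.train()
x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
out = head(x)
# With end2end=True, training output is {"one2many": {...}, "one2one": {...}}; each branch holds
# {"boxes": (bs, 4 * reg_max, anchors), "scores": (bs, nc, anchors), "feats": [per-level maps]}.
print(out["one2many"]["boxes"].shape)  # expected torch.Size([1, 64, 8400])
print(out["one2one"]["scores"].shape)  # expected torch.Size([1, 80, 8400])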
@@ -206,34 +200,49 @@ class Detect(nn.Module):
             dim=1,
         )

-
-
-        """
-        Post-process YOLO model predictions.
+    def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
+        """Post-processes YOLO model predictions.

         Args:
             preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
                 format [x, y, w, h, class_probs].
-            max_det (int): Maximum detections per image.
-            nc (int, optional): Number of classes.

         Returns:
             (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
                 dimension format [x, y, w, h, max_class_prob, class_index].
         """
-
-
-
-
-
-
-
-
+        boxes, scores = preds.split([4, self.nc], dim=-1)
+        scores, conf, idx = self.get_topk_index(scores, self.max_det)
+        boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
+        return torch.cat([boxes, scores, conf], dim=-1)
+
+    def get_topk_index(self, scores: torch.Tensor, max_det: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Get top-k indices from scores.
+
+        Args:
+            scores (torch.Tensor): Scores tensor with shape (batch_size, num_anchors, num_classes).
+            max_det (int): Maximum detections per image.
+
+        Returns:
+            (torch.Tensor, torch.Tensor, torch.Tensor): Top scores, class indices, and filtered indices.
+        """
+        batch_size, anchors, nc = scores.shape  # i.e. shape(16,8400,84)
+        # Use max_det directly during export for TensorRT compatibility (requires k to be constant),
+        # otherwise use min(max_det, anchors) for safety with small inputs during Python inference
+        k = max_det if self.export else min(max_det, anchors)
+        ori_index = scores.max(dim=-1)[0].topk(k)[1].unsqueeze(-1)
+        scores = scores.gather(dim=1, index=ori_index.repeat(1, 1, nc))
+        scores, index = scores.flatten(1).topk(k)
+        idx = ori_index[torch.arange(batch_size)[..., None], index // nc]  # original index
+        return scores[..., None], (index % nc)[..., None].float(), idx
+
+    def fuse(self) -> None:
+        """Remove the one2many head for inference optimization."""
+        self.cv2 = self.cv3 = None


 class Segment(Detect):
-    """
-    YOLO Segment head for segmentation models.
+    """YOLO Segment head for segmentation models.

     This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.

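Standalone sketch (editorial note, not part of the diff) of the top-k selection that the new get_topk_index()/postprocess() pair performs instead of NMS for end-to-end heads. Shapes are illustrative assumptions only (batch 2, 8400 anchors, 80 classes, max_det 300); the logic mirrors the added lines above.

import torch

batch_size, anchors, nc, max_det = 2, 8400, 80, 300
scores = torch.rand(batch_size, anchors, nc)  # per-anchor class scores

k = min(max_det, anchors)  # export mode would use max_det directly so k stays constant
ori_index = scores.max(dim=-1)[0].topk(k)[1].unsqueeze(-1)         # top-k anchors by best class score
top = scores.gather(dim=1, index=ori_index.repeat(1, 1, nc))       # class scores of those anchors
top_scores, index = top.flatten(1).topk(k)                         # joint (anchor, class) top-k
idx = ori_index[torch.arange(batch_size)[..., None], index // nc]  # map back to original anchor indices
cls = (index % nc)[..., None].float()                              # class id per kept prediction
print(top_scores.shape, cls.shape, idx.shape)  # torch.Size([2, 300]) torch.Size([2, 300, 1]) torch.Size([2, 300, 1])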
@@ -253,39 +262,150 @@ class Segment(Detect):
     >>> outputs = segment(x)
     """

-    def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, ch: tuple = ()):
-        """
-        Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
+    def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, reg_max=16, end2end=False, ch: tuple = ()):
+        """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.

         Args:
             nc (int): Number of classes.
             nm (int): Number of masks.
             npr (int): Number of protos.
+            reg_max (int): Maximum number of DFL channels.
+            end2end (bool): Whether to use end-to-end NMS-free detection.
             ch (tuple): Tuple of channel sizes from backbone feature maps.
         """
-        super().__init__(nc, ch)
+        super().__init__(nc, reg_max, end2end, ch)
         self.nm = nm  # number of masks
         self.npr = npr  # number of protos
         self.proto = Proto(ch[0], self.npr, self.nm)  # protos

         c4 = max(ch[0] // 4, self.nm)
         self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
+        if end2end:
+            self.one2one_cv4 = copy.deepcopy(self.cv4)
+
+    @property
+    def one2many(self):
+        """Returns the one-to-many head components, here for backward compatibility."""
+        return dict(box_head=self.cv2, cls_head=self.cv3, mask_head=self.cv4)
+
+    @property
+    def one2one(self):
+        """Returns the one-to-one head components."""
+        return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, mask_head=self.one2one_cv4)

-    def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor]:
+    def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
         """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
-
-
+        outputs = super().forward(x)
+        preds = outputs[1] if isinstance(outputs, tuple) else outputs
+        proto = self.proto(x[0])  # mask protos
+        if isinstance(preds, dict):  # training and validating during training
+            if self.end2end:
+                preds["one2many"]["proto"] = proto
+                preds["one2one"]["proto"] = proto.detach()
+            else:
+                preds["proto"] = proto
+            if self.training:
+                return preds
+        return (outputs, proto) if self.export else ((outputs[0], proto), preds)
+
+    def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+        """Decode predicted bounding boxes and class probabilities, concatenated with mask coefficients."""
+        preds = super()._inference(x)
+        return torch.cat([preds, x["mask_coefficient"]], dim=1)
+
+    def forward_head(
+        self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, mask_head: torch.nn.Module
+    ) -> torch.Tensor:
+        """Concatenates and returns predicted bounding boxes, class probabilities, and mask coefficients."""
+        preds = super().forward_head(x, box_head, cls_head)
+        if mask_head is not None:
+            bs = x[0].shape[0]  # batch size
+            preds["mask_coefficient"] = torch.cat([mask_head[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
+        return preds
+
+    def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
+        """Post-process YOLO model predictions.
+
+        Args:
+            preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nm) with last dimension
+                format [x, y, w, h, class_probs, mask_coefficient].
+
+        Returns:
+            (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + nm) and last
+                dimension format [x, y, w, h, max_class_prob, class_index, mask_coefficient].
+        """
+        boxes, scores, mask_coefficient = preds.split([4, self.nc, self.nm], dim=-1)
+        scores, conf, idx = self.get_topk_index(scores, self.max_det)
+        boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
+        mask_coefficient = mask_coefficient.gather(dim=1, index=idx.repeat(1, 1, self.nm))
+        return torch.cat([boxes, scores, conf, mask_coefficient], dim=-1)
+
+    def fuse(self) -> None:
+        """Remove the one2many head for inference optimization."""
+        self.cv2 = self.cv3 = self.cv4 = None

-
-
+
+class Segment26(Segment):
+    """YOLO26 Segment head for segmentation models.
+
+    This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.
+
+    Attributes:
+        nm (int): Number of masks.
+        npr (int): Number of protos.
+        proto (Proto): Prototype generation module.
+        cv4 (nn.ModuleList): Convolution layers for mask coefficients.
+
+    Methods:
+        forward: Return model outputs and mask coefficients.
+
+    Examples:
+        Create a segmentation head
+        >>> segment = Segment26(nc=80, nm=32, npr=256, ch=(256, 512, 1024))
+        >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+        >>> outputs = segment(x)
+    """
+
+    def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, reg_max=16, end2end=False, ch: tuple = ()):
+        """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
+
+        Args:
+            nc (int): Number of classes.
+            nm (int): Number of masks.
+            npr (int): Number of protos.
+            reg_max (int): Maximum number of DFL channels.
+            end2end (bool): Whether to use end-to-end NMS-free detection.
+            ch (tuple): Tuple of channel sizes from backbone feature maps.
+        """
+        super().__init__(nc, nm, npr, reg_max, end2end, ch)
+        self.proto = Proto26(ch, self.npr, self.nm, nc)  # protos
+
+    def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
+        """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
+        outputs = Detect.forward(self, x)
+        preds = outputs[1] if isinstance(outputs, tuple) else outputs
+        proto = self.proto(x)  # mask protos
+        if isinstance(preds, dict):  # training and validating during training
+            if self.end2end:
+                preds["one2many"]["proto"] = proto
+                preds["one2one"]["proto"] = (
+                    tuple(p.detach() for p in proto) if isinstance(proto, tuple) else proto.detach()
+                )
+            else:
+                preds["proto"] = proto
             if self.training:
-                return
-        return (
+                return preds
+        return (outputs, proto) if self.export else ((outputs[0], proto), preds)
+
+    def fuse(self) -> None:
+        """Remove the one2many head and extra part of proto module for inference optimization."""
+        super().fuse()
+        if hasattr(self.proto, "fuse"):
+            self.proto.fuse()


 class OBB(Detect):
-    """
-    YOLO OBB detection head for detection with rotation models.
+    """YOLO OBB detection head for detection with rotation models.

     This class extends the Detect head to include oriented bounding box prediction with rotation angles.

@@ -305,43 +425,117 @@ class OBB(Detect):
     >>> outputs = obb(x)
     """

-    def __init__(self, nc: int = 80, ne: int = 1, ch: tuple = ()):
-        """
-        Initialize OBB with number of classes `nc` and layer channels `ch`.
+    def __init__(self, nc: int = 80, ne: int = 1, reg_max=16, end2end=False, ch: tuple = ()):
+        """Initialize OBB with number of classes `nc` and layer channels `ch`.

         Args:
             nc (int): Number of classes.
             ne (int): Number of extra parameters.
+            reg_max (int): Maximum number of DFL channels.
+            end2end (bool): Whether to use end-to-end NMS-free detection.
             ch (tuple): Tuple of channel sizes from backbone feature maps.
         """
-        super().__init__(nc, ch)
+        super().__init__(nc, reg_max, end2end, ch)
         self.ne = ne  # number of extra parameters

         c4 = max(ch[0] // 4, self.ne)
         self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if end2end:
+            self.one2one_cv4 = copy.deepcopy(self.cv4)
+
+    @property
+    def one2many(self):
+        """Returns the one-to-many head components, here for backward compatibility."""
+        return dict(box_head=self.cv2, cls_head=self.cv3, angle_head=self.cv4)
+
+    @property
+    def one2one(self):
+        """Returns the one-to-one head components."""
+        return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, angle_head=self.one2one_cv4)
+
+    def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+        """Decode predicted bounding boxes and class probabilities, concatenated with rotation angles."""
+        # For decode_bboxes convenience
+        self.angle = x["angle"]  # TODO: need to test obb
+        preds = super()._inference(x)
+        return torch.cat([preds, x["angle"]], dim=1)
+
+    def forward_head(
+        self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, angle_head: torch.nn.Module
+    ) -> torch.Tensor:
+        """Concatenates and returns predicted bounding boxes, class probabilities, and angles."""
+        preds = super().forward_head(x, box_head, cls_head)
+        if angle_head is not None:
+            bs = x[0].shape[0]  # batch size
+            angle = torch.cat(
+                [angle_head[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2
+            )  # OBB theta logits
+            angle = (angle.sigmoid() - 0.25) * math.pi  # [-pi/4, 3pi/4]
+            preds["angle"] = angle
+        return preds

     def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
         """Decode rotated bounding boxes."""
         return dist2rbox(bboxes, self.angle, anchors, dim=1)

+    def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
+        """Post-process YOLO model predictions.

-
+        Args:
+            preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + ne) with last dimension
+                format [x, y, w, h, class_probs, angle].
+
+        Returns:
+            (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 7) and last
+                dimension format [x, y, w, h, max_class_prob, class_index, angle].
+        """
+        boxes, scores, angle = preds.split([4, self.nc, self.ne], dim=-1)
+        scores, conf, idx = self.get_topk_index(scores, self.max_det)
+        boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
+        angle = angle.gather(dim=1, index=idx.repeat(1, 1, self.ne))
+        return torch.cat([boxes, scores, conf, angle], dim=-1)
+
+    def fuse(self) -> None:
+        """Remove the one2many head for inference optimization."""
+        self.cv2 = self.cv3 = self.cv4 = None
+
+
+class OBB26(OBB):
+    """YOLO26 OBB detection head for detection with rotation models. This class extends the OBB head with modified angle
+    processing that outputs raw angle predictions without sigmoid transformation, compared to the original
+    OBB class.
+
+    Attributes:
+        ne (int): Number of extra parameters.
+        cv4 (nn.ModuleList): Convolution layers for angle prediction.
+        angle (torch.Tensor): Predicted rotation angles.
+
+    Methods:
+        forward_head: Concatenate and return predicted bounding boxes, class probabilities, and raw angles.
+
+    Examples:
+        Create an OBB26 detection head
+        >>> obb26 = OBB26(nc=80, ne=1, ch=(256, 512, 1024))
+        >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+        >>> outputs = obb26(x).
     """
-
+
+    def forward_head(
+        self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, angle_head: torch.nn.Module
+    ) -> torch.Tensor:
+        """Concatenates and returns predicted bounding boxes, class probabilities, and raw angles."""
+        preds = Detect.forward_head(self, x, box_head, cls_head)
+        if angle_head is not None:
+            bs = x[0].shape[0]  # batch size
+            angle = torch.cat(
+                [angle_head[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2
+            )  # OBB theta logits (raw output without sigmoid transformation)
+            preds["angle"] = angle
+        return preds
+
+
+class Pose(Detect):
+    """YOLO Pose head for keypoints models.

     This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.

@@ -361,50 +555,78 @@ class Pose(Detect):
     >>> outputs = pose(x)
     """

-    def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), ch: tuple = ()):
-        """
-        Initialize YOLO network with default parameters and Convolutional Layers.
+    def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), reg_max=16, end2end=False, ch: tuple = ()):
+        """Initialize YOLO network with default parameters and Convolutional Layers.

         Args:
             nc (int): Number of classes.
             kpt_shape (tuple): Number of keypoints, number of dims (2 for x,y or 3 for x,y,visible).
+            reg_max (int): Maximum number of DFL channels.
+            end2end (bool): Whether to use end-to-end NMS-free detection.
             ch (tuple): Tuple of channel sizes from backbone feature maps.
         """
-        super().__init__(nc, ch)
+        super().__init__(nc, reg_max, end2end, ch)
         self.kpt_shape = kpt_shape  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
         self.nk = kpt_shape[0] * kpt_shape[1]  # number of keypoints total

         c4 = max(ch[0] // 4, self.nk)
         self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
+        if end2end:
+            self.one2one_cv4 = copy.deepcopy(self.cv4)
+
+    @property
+    def one2many(self):
+        """Returns the one-to-many head components, here for backward compatibility."""
+        return dict(box_head=self.cv2, cls_head=self.cv3, pose_head=self.cv4)
+
+    @property
+    def one2one(self):
+        """Returns the one-to-one head components."""
+        return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, pose_head=self.one2one_cv4)
+
+    def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+        """Decode predicted bounding boxes and class probabilities, concatenated with keypoints."""
+        preds = super()._inference(x)
+        return torch.cat([preds, self.kpts_decode(x["kpts"])], dim=1)
+
+    def forward_head(
+        self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, pose_head: torch.nn.Module
+    ) -> torch.Tensor:
+        """Concatenates and returns predicted bounding boxes, class probabilities, and keypoints."""
+        preds = super().forward_head(x, box_head, cls_head)
+        if pose_head is not None:
+            bs = x[0].shape[0]  # batch size
+            preds["kpts"] = torch.cat([pose_head[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], 2)
+        return preds
+
+    def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
+        """Post-process YOLO model predictions.

-
-
-
-        kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
-        x = Detect.forward(self, x)
-        if self.training:
-            return x, kpt
-        pred_kpt = self.kpts_decode(bs, kpt)
-        return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
+        Args:
+            preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nk) with last dimension
+                format [x, y, w, h, class_probs, keypoints].

-
+        Returns:
+            (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + self.nk) and
+                last dimension format [x, y, w, h, max_class_prob, class_index, keypoints].
+        """
+        boxes, scores, kpts = preds.split([4, self.nc, self.nk], dim=-1)
+        scores, conf, idx = self.get_topk_index(scores, self.max_det)
+        boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
+        kpts = kpts.gather(dim=1, index=idx.repeat(1, 1, self.nk))
+        return torch.cat([boxes, scores, conf, kpts], dim=-1)
+
+    def fuse(self) -> None:
+        """Remove the one2many head for inference optimization."""
+        self.cv2 = self.cv3 = self.cv4 = None
+
+    def kpts_decode(self, kpts: torch.Tensor) -> torch.Tensor:
         """Decode keypoints from predictions."""
         ndim = self.kpt_shape[1]
+        bs = kpts.shape[0]
         if self.export:
-
-
-                "edgetpu",
-            }:  # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
-                # Precompute normalization factor to increase numerical stability
-                y = kpts.view(bs, *self.kpt_shape, -1)
-                grid_h, grid_w = self.shape[2], self.shape[3]
-                grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
-                norm = self.strides / (self.stride[0] * grid_size)
-                a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
-            else:
-                # NCNN fix
-                y = kpts.view(bs, *self.kpt_shape, -1)
-                a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
+            y = kpts.view(bs, *self.kpt_shape, -1)
+            a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
             if ndim == 3:
                 a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
             return a.view(bs, self.nk, -1)
@@ -420,9 +642,125 @@ class Pose(Detect):
             return y


-class
+class Pose26(Pose):
+    """YOLO26 Pose head for keypoints models.
+
+    This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.
+
+    Attributes:
+        kpt_shape (tuple): Number of keypoints and dimensions (2 for x,y or 3 for x,y,visible).
+        nk (int): Total number of keypoint values.
+        cv4 (nn.ModuleList): Convolution layers for keypoint prediction.
+
+    Methods:
+        forward: Perform forward pass through YOLO model and return predictions.
+        kpts_decode: Decode keypoints from predictions.
+
+    Examples:
+        Create a pose detection head
+        >>> pose = Pose(nc=80, kpt_shape=(17, 3), ch=(256, 512, 1024))
+        >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+        >>> outputs = pose(x)
     """
-
+
+    def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), reg_max=16, end2end=False, ch: tuple = ()):
+        """Initialize YOLO network with default parameters and Convolutional Layers.
+
+        Args:
+            nc (int): Number of classes.
+            kpt_shape (tuple): Number of keypoints, number of dims (2 for x,y or 3 for x,y,visible).
+            reg_max (int): Maximum number of DFL channels.
+            end2end (bool): Whether to use end-to-end NMS-free detection.
+            ch (tuple): Tuple of channel sizes from backbone feature maps.
+        """
+        super().__init__(nc, kpt_shape, reg_max, end2end, ch)
+        self.flow_model = RealNVP()
+
+        c4 = max(ch[0] // 4, kpt_shape[0] * (kpt_shape[1] + 2))
+        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3)) for x in ch)
+
+        self.cv4_kpts = nn.ModuleList(nn.Conv2d(c4, self.nk, 1) for _ in ch)
+        self.nk_sigma = kpt_shape[0] * 2  # sigma_x, sigma_y for each keypoint
+        self.cv4_sigma = nn.ModuleList(nn.Conv2d(c4, self.nk_sigma, 1) for _ in ch)
+
+        if end2end:
+            self.one2one_cv4 = copy.deepcopy(self.cv4)
+            self.one2one_cv4_kpts = copy.deepcopy(self.cv4_kpts)
+            self.one2one_cv4_sigma = copy.deepcopy(self.cv4_sigma)
+
+    @property
+    def one2many(self):
+        """Returns the one-to-many head components, here for backward compatibility."""
+        return dict(
+            box_head=self.cv2,
+            cls_head=self.cv3,
+            pose_head=self.cv4,
+            kpts_head=self.cv4_kpts,
+            kpts_sigma_head=self.cv4_sigma,
+        )
+
+    @property
+    def one2one(self):
+        """Returns the one-to-one head components."""
+        return dict(
+            box_head=self.one2one_cv2,
+            cls_head=self.one2one_cv3,
+            pose_head=self.one2one_cv4,
+            kpts_head=self.one2one_cv4_kpts,
+            kpts_sigma_head=self.one2one_cv4_sigma,
+        )
+
+    def forward_head(
+        self,
+        x: list[torch.Tensor],
+        box_head: torch.nn.Module,
+        cls_head: torch.nn.Module,
+        pose_head: torch.nn.Module,
+        kpts_head: torch.nn.Module,
+        kpts_sigma_head: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Concatenates and returns predicted bounding boxes, class probabilities, and keypoints."""
+        preds = Detect.forward_head(self, x, box_head, cls_head)
+        if pose_head is not None:
+            bs = x[0].shape[0]  # batch size
+            features = [pose_head[i](x[i]) for i in range(self.nl)]
+            preds["kpts"] = torch.cat([kpts_head[i](features[i]).view(bs, self.nk, -1) for i in range(self.nl)], 2)
+            if self.training:
+                preds["kpts_sigma"] = torch.cat(
+                    [kpts_sigma_head[i](features[i]).view(bs, self.nk_sigma, -1) for i in range(self.nl)], 2
+                )
+        return preds
+
+    def fuse(self) -> None:
+        """Remove the one2many head for inference optimization."""
+        super().fuse()
+        self.cv4_kpts = self.cv4_sigma = self.flow_model = self.one2one_cv4_sigma = None
+
+    def kpts_decode(self, kpts: torch.Tensor) -> torch.Tensor:
+        """Decode keypoints from predictions."""
+        ndim = self.kpt_shape[1]
+        bs = kpts.shape[0]
+        if self.export:
+            y = kpts.view(bs, *self.kpt_shape, -1)
+            # NCNN fix
+            a = (y[:, :, :2] + self.anchors) * self.strides
+            if ndim == 3:
+                a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
+            return a.view(bs, self.nk, -1)
+        else:
+            y = kpts.clone()
+            if ndim == 3:
+                if NOT_MACOS14:
+                    y[:, 2::ndim].sigmoid_()
+                else:  # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
+                    y[:, 2::ndim] = y[:, 2::ndim].sigmoid()
+            y[:, 0::ndim] = (y[:, 0::ndim] + self.anchors[0]) * self.strides
+            y[:, 1::ndim] = (y[:, 1::ndim] + self.anchors[1]) * self.strides
+            return y
+
+
+class Classify(nn.Module):
+    """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).

     This class implements a classification head that transforms feature maps into class predictions.

@@ -446,8 +784,7 @@ class Classify(nn.Module):
     export = False  # export mode

     def __init__(self, c1: int, c2: int, k: int = 1, s: int = 1, p: int | None = None, g: int = 1):
-        """
-        Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.
+        """Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.

         Args:
             c1 (int): Number of input channels.
@@ -476,11 +813,10 @@ class Classify(nn.Module):


 class WorldDetect(Detect):
-    """
-    Head for integrating YOLO detection models with semantic understanding from text embeddings.
+    """Head for integrating YOLO detection models with semantic understanding from text embeddings.

-    This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding
-
+    This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding in
+    object detection tasks.

     Attributes:
         cv3 (nn.ModuleList): Convolution layers for embedding features.
@@ -498,30 +834,44 @@ class WorldDetect(Detect):
     >>> outputs = world_detect(x, text)
     """

-    def __init__(
-
-
+    def __init__(
+        self,
+        nc: int = 80,
+        embed: int = 512,
+        with_bn: bool = False,
+        reg_max: int = 16,
+        end2end: bool = False,
+        ch: tuple = (),
+    ):
+        """Initialize YOLO detection layer with nc classes and layer channels ch.

         Args:
             nc (int): Number of classes.
             embed (int): Embedding dimension.
             with_bn (bool): Whether to use batch normalization in contrastive head.
+            reg_max (int): Maximum number of DFL channels.
+            end2end (bool): Whether to use end-to-end NMS-free detection.
             ch (tuple): Tuple of channel sizes from backbone feature maps.
         """
-        super().__init__(nc, ch)
+        super().__init__(nc, reg_max=reg_max, end2end=end2end, ch=ch)
         c3 = max(ch[0], min(self.nc, 100))
         self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
         self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)

-    def forward(self, x: list[torch.Tensor], text: torch.Tensor) ->
+    def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> dict[str, torch.Tensor] | tuple:
         """Concatenate and return predicted bounding boxes and class probabilities."""
+        feats = [xi.clone() for xi in x]  # save original features for anchor generation
         for i in range(self.nl):
             x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
-        if self.training:
-            return x
         self.no = self.nc + self.reg_max * 4  # self.nc could be changed when inference with different texts
-
-
+        bs = x[0].shape[0]
+        x_cat = torch.cat([xi.view(bs, self.no, -1) for xi in x], 2)
+        boxes, scores = x_cat.split((self.reg_max * 4, self.nc), 1)
+        preds = dict(boxes=boxes, scores=scores, feats=feats)
+        if self.training:
+            return preds
+        y = self._inference(preds)
+        return y if self.export else (y, preds)

     def bias_init(self):
         """Initialize Detect() biases, WARNING: requires stride availability."""
@@ -534,11 +884,10 @@ class WorldDetect(Detect):


 class LRPCHead(nn.Module):
-    """
-    Lightweight Region Proposal and Classification Head for efficient object detection.
+    """Lightweight Region Proposal and Classification Head for efficient object detection.

-    This head combines region proposal filtering with classification to enable efficient detection with
-
+    This head combines region proposal filtering with classification to enable efficient detection with dynamic
+    vocabulary support.

     Attributes:
         vocab (nn.Module): Vocabulary/classification layer.
@@ -559,8 +908,7 @@ class LRPCHead(nn.Module):
     """

     def __init__(self, vocab: nn.Module, pf: nn.Module, loc: nn.Module, enabled: bool = True):
-        """
-        Initialize LRPCHead with vocabulary, proposal filter, and localization components.
+        """Initialize LRPCHead with vocabulary, proposal filter, and localization components.

         Args:
             vocab (nn.Module): Vocabulary/classification module.
@@ -574,7 +922,8 @@ class LRPCHead(nn.Module):
         self.loc = loc
         self.enabled = enabled

-
+    @staticmethod
+    def conv2linear(conv: nn.Conv2d) -> nn.Linear:
         """Convert a 1x1 convolutional layer to a linear layer."""
         assert isinstance(conv, nn.Conv2d) and conv.kernel_size == (1, 1)
         linear = nn.Linear(conv.in_channels, conv.out_channels)
@@ -589,18 +938,19 @@ class LRPCHead(nn.Module):
             mask = pf_score.sigmoid() > conf
             cls_feat = cls_feat.flatten(2).transpose(-1, -2)
             cls_feat = self.vocab(cls_feat[:, mask] if conf else cls_feat * mask.unsqueeze(-1).int())
-            return
+            return self.loc(loc_feat), cls_feat.transpose(-1, -2), mask
         else:
             cls_feat = self.vocab(cls_feat)
             loc_feat = self.loc(loc_feat)
-            return (
-
+            return (
+                loc_feat,
+                cls_feat.flatten(2),
+                torch.ones(cls_feat.shape[2] * cls_feat.shape[3], device=cls_feat.device, dtype=torch.bool),
             )


 class YOLOEDetect(Detect):
-    """
-    Head for integrating YOLO detection models with semantic understanding from text embeddings.
+    """Head for integrating YOLO detection models with semantic understanding from text embeddings.

     This class extends the standard Detect head to support text-guided detection with enhanced semantic understanding
     through text embeddings and visual prompt embeddings.
@@ -631,17 +981,20 @@ class YOLOEDetect(Detect):

     is_fused = False

-    def __init__(
-
-
+    def __init__(
+        self, nc: int = 80, embed: int = 512, with_bn: bool = False, reg_max=16, end2end=False, ch: tuple = ()
+    ):
+        """Initialize YOLO detection layer with nc classes and layer channels ch.

         Args:
             nc (int): Number of classes.
             embed (int): Embedding dimension.
             with_bn (bool): Whether to use batch normalization in contrastive head.
+            reg_max (int): Maximum number of DFL channels.
+            end2end (bool): Whether to use end-to-end NMS-free detection.
             ch (tuple): Tuple of channel sizes from backbone feature maps.
         """
-        super().__init__(nc, ch)
+        super().__init__(nc, reg_max, end2end, ch)
         c3 = max(ch[0], min(self.nc, 100))
         assert c3 <= embed
         assert with_bn
@@ -657,29 +1010,43 @@ class YOLOEDetect(Detect):
                 for x in ch
             )
         )
-
         self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)
+        if end2end:
+            self.one2one_cv3 = copy.deepcopy(self.cv3)  # overwrite with new cv3
+            self.one2one_cv4 = copy.deepcopy(self.cv4)

         self.reprta = Residual(SwiGLUFFN(embed, embed))
         self.savpe = SAVPE(ch, c3, embed)
         self.embed = embed

     @smart_inference_mode()
-    def fuse(self, txt_feats: torch.Tensor):
+    def fuse(self, txt_feats: torch.Tensor = None):
         """Fuse text features with model weights for efficient inference."""
+        if txt_feats is None:  # means eliminate one2many branch
+            self.cv2 = self.cv3 = self.cv4 = None
+            return
         if self.is_fused:
             return

         assert not self.training
         txt_feats = txt_feats.to(torch.float32).squeeze(0)
-
-
-
-
+        self._fuse_tp(txt_feats, self.cv3, self.cv4)
+        if self.end2end:
+            self._fuse_tp(txt_feats, self.one2one_cv3, self.one2one_cv4)
+        del self.reprta
+        self.reprta = nn.Identity()
+        self.is_fused = True
+
+    def _fuse_tp(self, txt_feats: torch.Tensor, cls_head: torch.nn.Module, bn_head: torch.nn.Module) -> None:
+        """Fuse text prompt embeddings with model weights for efficient inference."""
+        for cls_h, bn_h in zip(cls_head, bn_head):
+            assert isinstance(cls_h, nn.Sequential)
+            assert isinstance(bn_h, BNContrastiveHead)
+            conv = cls_h[-1]
             assert isinstance(conv, nn.Conv2d)
-            logit_scale =
-            bias =
-            norm =
+            logit_scale = bn_h.logit_scale
+            bias = bn_h.bias
+            norm = bn_h.norm

             t = txt_feats * logit_scale.exp()
             conv: nn.Conv2d = fuse_conv_and_bn(conv, norm)
@@ -703,13 +1070,9 @@ class YOLOEDetect(Detect):

             conv.weight.data.copy_(w.unsqueeze(-1).unsqueeze(-1))
             conv.bias.data.copy_(b1 + b2)
-
+            cls_h[-1] = conv

-
-
-        del self.reprta
-        self.reprta = nn.Identity()
-        self.is_fused = True
+            bn_h.fuse()

     def get_tpe(self, tpe: torch.Tensor | None) -> torch.Tensor | None:
         """Get text prompt embeddings with normalization."""
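The new `_fuse_tp` above folds the text-prompt contrastive step into the final 1x1 classification conv. A minimal sketch of that identity under assumed sizes (embed 512, 64 input channels, 80 prompts) and generic names, not the package's exact code:

```python
import torch

E, C, nc = 512, 64, 80                              # embed dim, in-channels, prompts (assumed)
conv = torch.nn.Conv2d(C, E, 1)                     # last 1x1 conv of a cls branch
txt = torch.nn.functional.normalize(torch.randn(nc, E), dim=-1)
t = txt * torch.tensor(2.0).exp()                   # scaled text embeddings (logit_scale.exp())

x = torch.randn(1, C, 8, 8)
ref = torch.einsum("behw,ne->bnhw", conv(x), t)     # contrastive-style class scores

fused = torch.nn.Conv2d(C, nc, 1)                   # the same scores from a single fused conv
fused.weight.data.copy_((t @ conv.weight.data.flatten(1)).unsqueeze(-1).unsqueeze(-1))
fused.bias.data.copy_(t @ conv.bias.data)
print(torch.allclose(ref, fused(x), atol=1e-3))     # True (up to float error)
```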
@@ -724,74 +1087,89 @@ class YOLOEDetect(Detect):
         assert vpe.ndim == 3  # (B, N, D)
         return vpe

-    def
+    def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
+        """Process features with class prompt embeddings to generate detections."""
+        if hasattr(self, "lrpc"):  # for prompt-free inference
+            return self.forward_lrpc(x[:3])
+        return super().forward(x)
+
+    def forward_lrpc(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
         """Process features with fused text embeddings to generate detections for prompt-free model."""
-
-
+        boxes, scores, index = [], [], []
+        bs = x[0].shape[0]
+        cv2 = self.cv2 if not self.end2end else self.one2one_cv2
+        cv3 = self.cv3 if not self.end2end else self.one2one_cv2
         for i in range(self.nl):
-            cls_feat =
-            loc_feat =
+            cls_feat = cv3[i](x[i])
+            loc_feat = cv2[i](x[i])
             assert isinstance(self.lrpc[i], LRPCHead)
-
-                cls_feat,
+            box, score, idx = self.lrpc[i](
+                cls_feat,
+                loc_feat,
+                0 if self.export and not self.dynamic else getattr(self, "conf", 0.001),
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        "
-        if
-        return
-
-
-
-
+            boxes.append(box.view(bs, self.reg_max * 4, -1))
+            scores.append(score)
+            index.append(idx)
+        preds = dict(boxes=torch.cat(boxes, 2), scores=torch.cat(scores, 2), feats=x, index=torch.cat(index))
+        y = self._inference(preds)
+        if self.end2end:
+            y = self.postprocess(y.permute(0, 2, 1))
+        return y if self.export else (y, preds)
+
+    def _get_decode_boxes(self, x):
+        """Decode predicted bounding boxes for inference."""
+        dbox = super()._get_decode_boxes(x)
+        if hasattr(self, "lrpc"):
+            dbox = dbox if self.export and not self.dynamic else dbox[..., x["index"]]
+        return dbox
+
+    @property
+    def one2many(self):
+        """Returns the one-to-many head components, here for v5/v5/v8/v9/11 backward compatibility."""
+        return dict(box_head=self.cv2, cls_head=self.cv3, contrastive_head=self.cv4)
+
+    @property
+    def one2one(self):
+        """Returns the one-to-one head components."""
+        return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, contrastive_head=self.one2one_cv4)
+
+    def forward_head(self, x, box_head, cls_head, contrastive_head):
+        """Concatenates and returns predicted bounding boxes, class probabilities, and text embeddings."""
+        assert len(x) == 4, f"Expected 4 features including 3 feature maps and 1 text embeddings, but got {len(x)}."
+        if box_head is None or cls_head is None:  # for fused inference
+            return dict()
+        bs = x[0].shape[0]  # batch size
+        boxes = torch.cat([box_head[i](x[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
+        self.nc = x[-1].shape[1]
+        scores = torch.cat(
+            [contrastive_head[i](cls_head[i](x[i]), x[-1]).reshape(bs, self.nc, -1) for i in range(self.nl)], dim=-1
+        )
         self.no = self.nc + self.reg_max * 4  # self.nc could be changed when inference with different texts
-
-        return y if self.export else (y, x)
+        return dict(boxes=boxes, scores=scores, feats=x[:3])

     def bias_init(self):
-        """Initialize biases
-
-
-
-
-            a[-1].bias.data[:] = 1.0  # box
-            # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
+        """Initialize Detect() biases, WARNING: requires stride availability."""
+        for i, (a, b, c) in enumerate(
+            zip(self.one2many["box_head"], self.one2many["cls_head"], self.one2many["contrastive_head"])
+        ):
+            a[-1].bias.data[:] = 2.0  # box
             b[-1].bias.data[:] = 0.0
-            c.bias.data[:] = math.log(5 /
+            c.bias.data[:] = math.log(5 / self.nc / (640 / self.stride[i]) ** 2)
+        if self.end2end:
+            for i, (a, b, c) in enumerate(
+                zip(self.one2one["box_head"], self.one2one["cls_head"], self.one2one["contrastive_head"])
+            ):
+                a[-1].bias.data[:] = 2.0  # box
+                b[-1].bias.data[:] = 0.0
+                c.bias.data[:] = math.log(5 / self.nc / (640 / self.stride[i]) ** 2)


 class YOLOESegment(YOLOEDetect):
-    """
-    YOLO segmentation head with text embedding capabilities.
+    """YOLO segmentation head with text embedding capabilities.

-    This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks
-
+    This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks with
+    text-guided semantic understanding.

     Attributes:
         nm (int): Number of masks.
@@ -811,10 +1189,17 @@ class YOLOESegment(YOLOEDetect):
     """

     def __init__(
-        self,
+        self,
+        nc: int = 80,
+        nm: int = 32,
+        npr: int = 256,
+        embed: int = 512,
+        with_bn: bool = False,
+        reg_max=16,
+        end2end=False,
+        ch: tuple = (),
     ):
-        """
-        Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
+        """Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.

         Args:
             nc (int): Number of classes.
@@ -822,41 +1207,195 @@ class YOLOESegment(YOLOEDetect):
             npr (int): Number of protos.
             embed (int): Embedding dimension.
             with_bn (bool): Whether to use batch normalization in contrastive head.
+            reg_max (int): Maximum number of DFL channels.
+            end2end (bool): Whether to use end-to-end NMS-free detection.
             ch (tuple): Tuple of channel sizes from backbone feature maps.
         """
-        super().__init__(nc, embed, with_bn, ch)
+        super().__init__(nc, embed, with_bn, reg_max, end2end, ch)
         self.nm = nm
         self.npr = npr
         self.proto = Proto(ch[0], self.npr, self.nm)

         c5 = max(ch[0] // 4, self.nm)
         self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
+        if end2end:
+            self.one2one_cv5 = copy.deepcopy(self.cv5)
+
+    @property
+    def one2many(self):
+        """Returns the one-to-many head components, here for v5/v5/v8/v9/11 backward compatibility."""
+        return dict(box_head=self.cv2, cls_head=self.cv3, mask_head=self.cv5, contrastive_head=self.cv4)
+
+    @property
+    def one2one(self):
+        """Returns the one-to-one head components."""
+        return dict(
+            box_head=self.one2one_cv2,
+            cls_head=self.one2one_cv3,
+            mask_head=self.one2one_cv5,
+            contrastive_head=self.one2one_cv4,
+        )
+
+    def forward_lrpc(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
+        """Process features with fused text embeddings to generate detections for prompt-free model."""
+        boxes, scores, index = [], [], []
+        bs = x[0].shape[0]
+        cv2 = self.cv2 if not self.end2end else self.one2one_cv2
+        cv3 = self.cv3 if not self.end2end else self.one2one_cv3
+        cv5 = self.cv5 if not self.end2end else self.one2one_cv5
+        for i in range(self.nl):
+            cls_feat = cv3[i](x[i])
+            loc_feat = cv2[i](x[i])
+            assert isinstance(self.lrpc[i], LRPCHead)
+            box, score, idx = self.lrpc[i](
+                cls_feat,
+                loc_feat,
+                0 if self.export and not self.dynamic else getattr(self, "conf", 0.001),
+            )
+            boxes.append(box.view(bs, self.reg_max * 4, -1))
+            scores.append(score)
+            index.append(idx)
+        mc = torch.cat([cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
+        index = torch.cat(index)
+        preds = dict(
+            boxes=torch.cat(boxes, 2),
+            scores=torch.cat(scores, 2),
+            feats=x,
+            index=index,
+            mask_coefficient=mc * index.int() if self.export and not self.dynamic else mc[..., index],
+        )
+        y = self._inference(preds)
+        if self.end2end:
+            y = self.postprocess(y.permute(0, 2, 1))
+        return y if self.export else (y, preds)

-    def forward(self, x: list[torch.Tensor]
+    def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
         """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
-
-
+        outputs = super().forward(x)
+        preds = outputs[1] if isinstance(outputs, tuple) else outputs
+        proto = self.proto(x[0])  # mask protos
+        if isinstance(preds, dict):  # training and validating during training
+            if self.end2end:
+                preds["one2many"]["proto"] = proto
+                preds["one2one"]["proto"] = proto.detach()
+            else:
+                preds["proto"] = proto
+            if self.training:
+                return preds
+        return (outputs, proto) if self.export else ((outputs[0], proto), preds)

-
-
+    def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+        """Decode predicted bounding boxes and class probabilities, concatenated with mask coefficients."""
+        preds = super()._inference(x)
+        return torch.cat([preds, x["mask_coefficient"]], dim=1)

-
-
-
-
+    def forward_head(
+        self,
+        x: list[torch.Tensor],
+        box_head: torch.nn.Module,
+        cls_head: torch.nn.Module,
+        mask_head: torch.nn.Module,
+        contrastive_head: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Concatenates and returns predicted bounding boxes, class probabilities, and mask coefficients."""
+        preds = super().forward_head(x, box_head, cls_head, contrastive_head)
+        if mask_head is not None:
+            bs = x[0].shape[0]  # batch size
+            preds["mask_coefficient"] = torch.cat([mask_head[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
+        return preds
+
+    def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
+        """Post-process YOLO model predictions.

-
-
+        Args:
+            preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nm) with last dimension
+                format [x, y, w, h, class_probs, mask_coefficient].

-
-
+        Returns:
+            (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + nm) and last
+                dimension format [x, y, w, h, max_class_prob, class_index, mask_coefficient].
+        """
+        boxes, scores, mask_coefficient = preds.split([4, self.nc, self.nm], dim=-1)
+        scores, conf, idx = self.get_topk_index(scores, self.max_det)
+        boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
+        mask_coefficient = mask_coefficient.gather(dim=1, index=idx.repeat(1, 1, self.nm))
+        return torch.cat([boxes, scores, conf, mask_coefficient], dim=-1)

-
+    def fuse(self, txt_feats: torch.Tensor = None):
+        """Fuse text features with model weights for efficient inference."""
+        super().fuse(txt_feats)
+        if txt_feats is None:  # means eliminate one2many branch
+            self.cv5 = None
+            if hasattr(self.proto, "fuse"):
+                self.proto.fuse()
+            return


-class
+class YOLOESegment26(YOLOESegment):
+    """YOLOE-style segmentation head module using Proto26 for mask generation.
+
+    This class extends the YOLOEDetect functionality to include segmentation capabilities by integrating a prototype
+    generation module and convolutional layers to predict mask coefficients.
+
+    Args:
+        nc (int): Number of classes. Defaults to 80.
+        nm (int): Number of masks. Defaults to 32.
+        npr (int): Number of prototype channels. Defaults to 256.
+        embed (int): Embedding dimensionality. Defaults to 512.
+        with_bn (bool): Whether to use Batch Normalization. Defaults to False.
+        reg_max (int): Maximum regression value for bounding boxes. Defaults to 16.
+        end2end (bool): Whether to use end-to-end detection mode. Defaults to False.
+        ch (tuple[int, ...]): Input channels for each scale.
+
+    Attributes:
+        nm (int): Number of segmentation masks.
+        npr (int): Number of prototype channels.
+        proto (Proto26): Prototype generation module for segmentation.
+        cv5 (nn.ModuleList): Convolutional layers for generating mask coefficients from features.
+        one2one_cv5 (nn.ModuleList, optional): Deep copy of cv5 for end-to-end detection branches.
     """
-
+
+    def __init__(
+        self,
+        nc: int = 80,
+        nm: int = 32,
+        npr: int = 256,
+        embed: int = 512,
+        with_bn: bool = False,
+        reg_max=16,
+        end2end=False,
+        ch: tuple = (),
+    ):
+        """Initialize YOLOESegment26 with class count, mask parameters, and embedding dimensions."""
+        YOLOEDetect.__init__(self, nc, embed, with_bn, reg_max, end2end, ch)
+        self.nm = nm
+        self.npr = npr
+        self.proto = Proto26(ch, self.npr, self.nm, nc)  # protos
+
+        c5 = max(ch[0] // 4, self.nm)
+        self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
+        if end2end:
+            self.one2one_cv5 = copy.deepcopy(self.cv5)
+
+    def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
+        """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
+        outputs = YOLOEDetect.forward(self, x)
+        preds = outputs[1] if isinstance(outputs, tuple) else outputs
+        proto = self.proto([xi.detach() for xi in x], return_semseg=False)  # mask protos
+
+        if isinstance(preds, dict):  # training and validating during training
+            if self.end2end and not hasattr(self, "lrpc"):  # not prompt-free
+                preds["one2many"]["proto"] = proto
+                preds["one2one"]["proto"] = proto.detach()
+            else:
+                preds["proto"] = proto
+            if self.training:
+                return preds
+        return (outputs, proto) if self.export else ((outputs[0], proto), preds)
+
+
+class RTDETRDecoder(nn.Module):
+    """Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.

     This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes
     and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
@@ -920,8 +1459,7 @@ class RTDETRDecoder(nn.Module):
         box_noise_scale: float = 1.0,
         learnt_init_query: bool = False,
     ):
-        """
-        Initialize the RTDETRDecoder module with the given parameters.
+        """Initialize the RTDETRDecoder module with the given parameters.

         Args:
             nc (int): Number of classes.
@@ -981,8 +1519,7 @@ class RTDETRDecoder(nn.Module):
         self._reset_parameters()

     def forward(self, x: list[torch.Tensor], batch: dict | None = None) -> tuple | torch.Tensor:
-        """
-        Run the forward pass of the module, returning bounding box and classification scores for the input.
+        """Run the forward pass of the module, returning bounding box and classification scores for the input.

         Args:
             x (list[torch.Tensor]): List of feature maps from the backbone.
@@ -1030,16 +1567,15 @@ class RTDETRDecoder(nn.Module):
         y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
         return y if self.export else (y, x)

+    @staticmethod
     def _generate_anchors(
-        self,
         shapes: list[list[int]],
         grid_size: float = 0.05,
         dtype: torch.dtype = torch.float32,
         device: str = "cpu",
         eps: float = 1e-2,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Generate anchor bounding boxes for given shapes with specific grid size and validate them.
+        """Generate anchor bounding boxes for given shapes with specific grid size and validate them.

         Args:
             shapes (list): List of feature map shapes.
@@ -1071,8 +1607,7 @@ class RTDETRDecoder(nn.Module):
         return anchors, valid_mask

     def _get_encoder_input(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, list[list[int]]]:
-        """
-        Process and return encoder inputs by getting projection features from input and concatenating them.
+        """Process and return encoder inputs by getting projection features from input and concatenating them.

         Args:
             x (list[torch.Tensor]): List of feature maps from the backbone.
@@ -1104,8 +1639,7 @@ class RTDETRDecoder(nn.Module):
         dn_embed: torch.Tensor | None = None,
         dn_bbox: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Generate and prepare the input required for the decoder from the provided features and shapes.
+        """Generate and prepare the input required for the decoder from the provided features and shapes.

         Args:
             feats (torch.Tensor): Processed features from encoder.
@@ -1129,9 +1663,9 @@ class RTDETRDecoder(nn.Module):
         enc_outputs_scores = self.enc_score_head(features)  # (bs, h*w, nc)

         # Query selection
-        # (bs,
+        # (bs*num_queries,)
         topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
-        # (bs,
+        # (bs*num_queries,)
         batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)

         # (bs, num_queries, 256)
@@ -1183,11 +1717,10 @@ class RTDETRDecoder(nn.Module):


 class v10Detect(Detect):
-    """
-    v10 Detection head from https://arxiv.org/pdf/2405.14458.
+    """v10 Detection head from https://arxiv.org/pdf/2405.14458.

-    This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions
-
+    This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions for
+    improved efficiency and performance.

     Attributes:
         end2end (bool): End-to-end detection mode.
@@ -1211,14 +1744,13 @@ class v10Detect(Detect):
     end2end = True

     def __init__(self, nc: int = 80, ch: tuple = ()):
-        """
-        Initialize the v10Detect object with the specified number of classes and input channels.
+        """Initialize the v10Detect object with the specified number of classes and input channels.

         Args:
             nc (int): Number of classes.
             ch (tuple): Tuple of channel sizes from backbone feature maps.
         """
-        super().__init__(nc, ch)
+        super().__init__(nc, end2end=True, ch=ch)
         c3 = max(ch[0], min(self.nc, 100))  # channels
         # Light cls head
         self.cv3 = nn.ModuleList(
@@ -1233,4 +1765,4 @@ class v10Detect(Detect):

     def fuse(self):
         """Remove the one2many head for inference optimization."""
-        self.cv2 = self.cv3 =
+        self.cv2 = self.cv3 = None
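For the shape contract documented in the new `YOLOESegment.postprocess` docstring above, a rough stand-alone sketch using plain `torch.topk` (a generic helper under assumed shapes, not the library's `get_topk_index`):

```python
import torch

def topk_postprocess(preds: torch.Tensor, nc: int, nm: int, max_det: int = 300) -> torch.Tensor:
    """Keep the max_det best anchors; mirrors the documented (bs, max_det, 6 + nm) layout."""
    boxes, scores, mc = preds.split([4, nc, nm], dim=-1)
    conf, cls = scores.max(dim=-1, keepdim=True)                 # best class prob and index per anchor
    conf, idx = conf.topk(min(max_det, preds.shape[1]), dim=1)   # top-k anchors by confidence
    boxes = boxes.gather(1, idx.repeat(1, 1, 4))
    cls = cls.gather(1, idx)
    mc = mc.gather(1, idx.repeat(1, 1, nm))
    return torch.cat([boxes, conf, cls.float(), mc], dim=-1)     # [x, y, w, h, score, class, mask coeffs]

out = topk_postprocess(torch.randn(2, 8400, 4 + 80 + 32), nc=80, nm=32)
print(out.shape)  # torch.Size([2, 300, 38])
```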