dgenerate_ultralytics_headless-8.3.134-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
ultralytics/nn/modules/transformer.py
@@ -0,0 +1,713 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Transformer modules."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import constant_, xavier_uniform_

from .conv import Conv
from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch

__all__ = (
    "TransformerEncoderLayer",
    "TransformerLayer",
    "TransformerBlock",
    "MLPBlock",
    "LayerNorm2d",
    "AIFI",
    "DeformableTransformerDecoder",
    "DeformableTransformerDecoderLayer",
    "MSDeformAttn",
    "MLP",
)

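The tuple above defines the module's public surface. As a quick orientation, a minimal import sketch (assuming the ultralytics package from this wheel is installed):

# Illustrative only: pull a few of the public classes straight from the installed package.
from ultralytics.nn.modules.transformer import AIFI, MLP, MSDeformAttn, TransformerEncoderLayer
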
class TransformerEncoderLayer(nn.Module):
    """
    Defines a single layer of the transformer encoder.

    Attributes:
        ma (nn.MultiheadAttention): Multi-head attention module.
        fc1 (nn.Linear): First linear layer in the feedforward network.
        fc2 (nn.Linear): Second linear layer in the feedforward network.
        norm1 (nn.LayerNorm): Layer normalization after attention.
        norm2 (nn.LayerNorm): Layer normalization after feedforward network.
        dropout (nn.Dropout): Dropout layer for the feedforward network.
        dropout1 (nn.Dropout): Dropout layer after attention.
        dropout2 (nn.Dropout): Dropout layer after feedforward network.
        act (nn.Module): Activation function.
        normalize_before (bool): Whether to apply normalization before attention and feedforward.
    """

    def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False):
        """
        Initialize the TransformerEncoderLayer with specified parameters.

        Args:
            c1 (int): Input dimension.
            cm (int): Hidden dimension in the feedforward network.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability.
            act (nn.Module): Activation function.
            normalize_before (bool): Whether to apply normalization before attention and feedforward.
        """
        super().__init__()
        from ...utils.torch_utils import TORCH_1_9

        if not TORCH_1_9:
            raise ModuleNotFoundError(
                "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)."
            )
        self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
        # Implementation of Feedforward model
        self.fc1 = nn.Linear(c1, cm)
        self.fc2 = nn.Linear(cm, c1)

        self.norm1 = nn.LayerNorm(c1)
        self.norm2 = nn.LayerNorm(c1)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.act = act
        self.normalize_before = normalize_before

    @staticmethod
    def with_pos_embed(tensor, pos=None):
        """Add position embeddings to the tensor if provided."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """
        Perform forward pass with post-normalization.

        Args:
            src (torch.Tensor): Input tensor.
            src_mask (torch.Tensor, optional): Mask for the src sequence.
            src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
            pos (torch.Tensor, optional): Positional encoding.

        Returns:
            (torch.Tensor): Output tensor after attention and feedforward.
        """
        q = k = self.with_pos_embed(src, pos)
        src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
        src = src + self.dropout2(src2)
        return self.norm2(src)

    def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """
        Perform forward pass with pre-normalization.

        Args:
            src (torch.Tensor): Input tensor.
            src_mask (torch.Tensor, optional): Mask for the src sequence.
            src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
            pos (torch.Tensor, optional): Positional encoding.

        Returns:
            (torch.Tensor): Output tensor after attention and feedforward.
        """
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
        return src + self.dropout2(src2)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """
        Forward propagates the input through the encoder module.

        Args:
            src (torch.Tensor): Input tensor.
            src_mask (torch.Tensor, optional): Mask for the src sequence.
            src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
            pos (torch.Tensor, optional): Positional encoding.

        Returns:
            (torch.Tensor): Output tensor after transformer encoder layer.
        """
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)

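A minimal usage sketch (illustrative, not part of the file): the attention is built with batch_first=True, so inputs are [batch, tokens, channels], and normalize_before selects the pre-norm path instead of the default post-norm one.

# Usage sketch with arbitrary sizes; requires torch>=1.9 per the check in __init__.
import torch

layer = TransformerEncoderLayer(c1=256, cm=2048, num_heads=8, normalize_before=False)
tokens = torch.randn(2, 400, 256)  # [batch, tokens, channels]
out = layer(tokens)                # post-norm path, shape preserved: [2, 400, 256]
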
class AIFI(TransformerEncoderLayer):
    """
    Defines the AIFI transformer layer.

    This class extends TransformerEncoderLayer to work with 2D data by adding positional embeddings.
    """

    def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False):
        """
        Initialize the AIFI instance with specified parameters.

        Args:
            c1 (int): Input dimension.
            cm (int): Hidden dimension in the feedforward network.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability.
            act (nn.Module): Activation function.
            normalize_before (bool): Whether to apply normalization before attention and feedforward.
        """
        super().__init__(c1, cm, num_heads, dropout, act, normalize_before)

    def forward(self, x):
        """
        Forward pass for the AIFI transformer layer.

        Args:
            x (torch.Tensor): Input tensor with shape [B, C, H, W].

        Returns:
            (torch.Tensor): Output tensor with shape [B, C, H, W].
        """
        c, h, w = x.shape[1:]
        pos_embed = self.build_2d_sincos_position_embedding(w, h, c)
        # Flatten [B, C, H, W] to [B, HxW, C]
        x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))
        return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()

    @staticmethod
    def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
        """
        Build 2D sine-cosine position embedding.

        Args:
            w (int): Width of the feature map.
            h (int): Height of the feature map.
            embed_dim (int): Embedding dimension.
            temperature (float): Temperature for the sine/cosine functions.

        Returns:
            (torch.Tensor): Position embedding with shape [1, embed_dim, h*w].
        """
        assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1.0 / (temperature**omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]

        return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]

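AIFI consumes a [B, C, H, W] feature map, flattens it to tokens, adds the 2D sin-cos embedding, and restores the original layout. A small sketch with arbitrary sizes; as the code shows, build_2d_sincos_position_embedding actually returns a [1, H*W, C] tensor and asserts C % 4 == 0.

# Usage sketch for AIFI on a single 20x20 feature map.
import torch

aifi = AIFI(c1=256, cm=1024, num_heads=8)
feat = torch.randn(1, 256, 20, 20)
out = aifi(feat)                                           # [1, 256, 20, 20]
pe = AIFI.build_2d_sincos_position_embedding(20, 20, 256)  # [1, 400, 256]
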
class TransformerLayer(nn.Module):
    """Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""

    def __init__(self, c, num_heads):
        """
        Initialize a self-attention mechanism using linear transformations and multi-head attention.

        Args:
            c (int): Input and output channel dimension.
            num_heads (int): Number of attention heads.
        """
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        """
        Apply a transformer block to the input x and return the output.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after transformer layer.
        """
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
        return self.fc2(self.fc1(x)) + x

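Unlike TransformerEncoderLayer above, TransformerLayer keeps nn.MultiheadAttention's default batch_first=False, so it expects sequence-first input. A brief sketch:

# Usage sketch: [tokens, batch, channels] in, same shape out (two residual branches).
import torch

tl = TransformerLayer(c=128, num_heads=4)
tokens = torch.randn(196, 2, 128)
out = tl(tokens)  # [196, 2, 128]
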
class TransformerBlock(nn.Module):
    """
    Vision Transformer https://arxiv.org/abs/2010.11929.

    Attributes:
        conv (Conv, optional): Convolution layer if input and output channels differ.
        linear (nn.Linear): Learnable position embedding.
        tr (nn.Sequential): Sequential container of transformer layers.
        c2 (int): Output channel dimension.
    """

    def __init__(self, c1, c2, num_heads, num_layers):
        """
        Initialize a Transformer module with position embedding and specified number of heads and layers.

        Args:
            c1 (int): Input channel dimension.
            c2 (int): Output channel dimension.
            num_heads (int): Number of attention heads.
            num_layers (int): Number of transformer layers.
        """
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
        self.c2 = c2

    def forward(self, x):
        """
        Forward propagates the input through the bottleneck module.

        Args:
            x (torch.Tensor): Input tensor with shape [b, c1, w, h].

        Returns:
            (torch.Tensor): Output tensor with shape [b, c2, w, h].
        """
        if self.conv is not None:
            x = self.conv(x)
        b, _, w, h = x.shape
        p = x.flatten(2).permute(2, 0, 1)
        return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)

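TransformerBlock is the conv-friendly wrapper: it projects c1 to c2 with a Conv only when the channel counts differ, flattens the map into tokens for the stacked TransformerLayers, then reshapes back. Illustrative sketch:

# Usage sketch with arbitrary sizes.
import torch

block = TransformerBlock(c1=64, c2=128, num_heads=4, num_layers=2)
x = torch.randn(2, 64, 20, 20)
y = block(x)  # [2, 128, 20, 20]
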
class MLPBlock(nn.Module):
    """Implements a single block of a multi-layer perceptron."""

    def __init__(self, embedding_dim, mlp_dim, act=nn.GELU):
        """
        Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.

        Args:
            embedding_dim (int): Input and output dimension.
            mlp_dim (int): Hidden dimension.
            act (nn.Module): Activation function.
        """
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MLPBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after MLP block.
        """
        return self.lin2(self.act(self.lin1(x)))

class MLP(nn.Module):
    """
    Implements a simple multi-layer perceptron (also called FFN).

    Attributes:
        num_layers (int): Number of layers in the MLP.
        layers (nn.ModuleList): List of linear layers.
        sigmoid (bool): Whether to apply sigmoid to the output.
        act (nn.Module): Activation function.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act=nn.ReLU, sigmoid=False):
        """
        Initialize the MLP with specified input, hidden, output dimensions and number of layers.

        Args:
            input_dim (int): Input dimension.
            hidden_dim (int): Hidden dimension.
            output_dim (int): Output dimension.
            num_layers (int): Number of layers.
            act (nn.Module): Activation function.
            sigmoid (bool): Whether to apply sigmoid to the output.
        """
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.sigmoid = sigmoid
        self.act = act()

    def forward(self, x):
        """
        Forward pass for the entire MLP.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after MLP.
        """
        for i, layer in enumerate(self.layers):
            x = getattr(self, "act", nn.ReLU())(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x.sigmoid() if getattr(self, "sigmoid", False) else x

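MLP stacks num_layers Linear layers and applies the activation between hidden layers only. A sketch of a 3-layer regression-style head (dimensions are arbitrary):

# Usage sketch: 256 -> 256 -> 256 -> 4, ReLU on the hidden layers, raw output by default.
import torch

mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
x = torch.randn(2, 300, 256)
out = mlp(x)  # [2, 300, 4]; pass sigmoid=True to squash outputs into [0, 1]
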
class LayerNorm2d(nn.Module):
    """
    2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.

    Original implementations in
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py
    and
    https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py.

    Attributes:
        weight (nn.Parameter): Learnable scale parameter.
        bias (nn.Parameter): Learnable bias parameter.
        eps (float): Small constant for numerical stability.
    """

    def __init__(self, num_channels, eps=1e-6):
        """
        Initialize LayerNorm2d with the given parameters.

        Args:
            num_channels (int): Number of channels in the input.
            eps (float): Small constant for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        """
        Perform forward pass for 2D layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Normalized output tensor.
        """
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]

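LayerNorm2d normalizes over the channel dimension of an NCHW tensor directly, so no permutation to channels-last is needed. Sketch:

# Usage sketch: per-position normalization over the 64 channels.
import torch

ln = LayerNorm2d(num_channels=64)
x = torch.randn(2, 64, 32, 32)
y = ln(x)  # [2, 64, 32, 32]
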
class MSDeformAttn(nn.Module):
    """
    Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.

    https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py

    Attributes:
        im2col_step (int): Step size for im2col operations.
        d_model (int): Model dimension.
        n_levels (int): Number of feature levels.
        n_heads (int): Number of attention heads.
        n_points (int): Number of sampling points per attention head per feature level.
        sampling_offsets (nn.Linear): Linear layer for generating sampling offsets.
        attention_weights (nn.Linear): Linear layer for generating attention weights.
        value_proj (nn.Linear): Linear layer for projecting values.
        output_proj (nn.Linear): Linear layer for projecting output.
    """

    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        """
        Initialize MSDeformAttn with the given parameters.

        Args:
            d_model (int): Model dimension.
            n_levels (int): Number of feature levels.
            n_heads (int): Number of attention heads.
            n_points (int): Number of sampling points per attention head per feature level.
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
        _d_per_head = d_model // n_heads
        # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation
        assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`"

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        """Reset module parameters."""
        constant_(self.sampling_offsets.weight.data, 0.0)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (
            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
            .view(self.n_heads, 1, 1, 2)
            .repeat(1, self.n_levels, self.n_points, 1)
        )
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.0)
        constant_(self.attention_weights.bias.data, 0.0)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.0)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.0)

    def forward(self, query, refer_bbox, value, value_shapes, value_mask=None):
        """
        Perform forward pass for multiscale deformable attention.

        https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py

        Args:
            query (torch.Tensor): Tensor with shape [bs, query_length, C].
            refer_bbox (torch.Tensor): Tensor with shape [bs, query_length, n_levels, 2], range in [0, 1],
                top-left (0,0), bottom-right (1, 1), including padding area.
            value (torch.Tensor): Tensor with shape [bs, value_length, C].
            value_shapes (list): List with shape [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})].
            value_mask (torch.Tensor, optional): Tensor with shape [bs, value_length], True for non-padding elements,
                False for padding elements.

        Returns:
            (torch.Tensor): Output tensor with shape [bs, Length_{query}, C].
        """
        bs, len_q = query.shape[:2]
        len_v = value.shape[1]
        assert sum(s[0] * s[1] for s in value_shapes) == len_v

        value = self.value_proj(value)
        if value_mask is not None:
            value = value.masked_fill(value_mask[..., None], float(0))
        value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, self.n_levels, self.n_points, 2)
        attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)
        attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)
        # N, Len_q, n_heads, n_levels, n_points, 2
        num_points = refer_bbox.shape[-1]
        if num_points == 2:
            offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
            add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            sampling_locations = refer_bbox[:, :, None, :, None, :] + add
        elif num_points == 4:
            add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
            sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.")
        output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
        return self.output_proj(output)

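A shape walkthrough may help here (all sizes are arbitrary): value is the concatenation of the flattened feature levels, value_shapes lists each level's (H, W), and refer_bbox holds normalized reference points per query and level; a trailing dimension of 2 selects the point branch, 4 the box branch.

# Shape sketch for MSDeformAttn with two feature levels.
import torch

attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4)
bs, num_queries = 2, 100
value_shapes = [(32, 32), (16, 16)]                 # (H, W) per level
len_v = sum(h * w for h, w in value_shapes)         # 1024 + 256 = 1280
query = torch.randn(bs, num_queries, 256)
value = torch.randn(bs, len_v, 256)
refer_bbox = torch.rand(bs, num_queries, 2, 2)      # [bs, queries, n_levels, 2] in [0, 1]
out = attn(query, refer_bbox, value, value_shapes)  # [2, 100, 256]
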
class DeformableTransformerDecoderLayer(nn.Module):
    """
    Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
    https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py

    Attributes:
        self_attn (nn.MultiheadAttention): Self-attention module.
        dropout1 (nn.Dropout): Dropout after self-attention.
        norm1 (nn.LayerNorm): Layer normalization after self-attention.
        cross_attn (MSDeformAttn): Cross-attention module.
        dropout2 (nn.Dropout): Dropout after cross-attention.
        norm2 (nn.LayerNorm): Layer normalization after cross-attention.
        linear1 (nn.Linear): First linear layer in the feedforward network.
        act (nn.Module): Activation function.
        dropout3 (nn.Dropout): Dropout in the feedforward network.
        linear2 (nn.Linear): Second linear layer in the feedforward network.
        dropout4 (nn.Dropout): Dropout after the feedforward network.
        norm3 (nn.LayerNorm): Layer normalization after the feedforward network.
    """

    def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0.0, act=nn.ReLU(), n_levels=4, n_points=4):
        """
        Initialize the DeformableTransformerDecoderLayer with the given parameters.

        Args:
            d_model (int): Model dimension.
            n_heads (int): Number of attention heads.
            d_ffn (int): Dimension of the feedforward network.
            dropout (float): Dropout probability.
            act (nn.Module): Activation function.
            n_levels (int): Number of feature levels.
            n_points (int): Number of sampling points.
        """
        super().__init__()

        # Self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Cross attention
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # FFN
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.act = act
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add positional embeddings to the input tensor, if provided."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """
        Perform forward pass through the Feed-Forward Network part of the layer.

        Args:
            tgt (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after FFN.
        """
        tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        return self.norm3(tgt)

    def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
        """
        Perform the forward pass through the entire decoder layer.

        Args:
            embed (torch.Tensor): Input embeddings.
            refer_bbox (torch.Tensor): Reference bounding boxes.
            feats (torch.Tensor): Feature maps.
            shapes (list): Feature shapes.
            padding_mask (torch.Tensor, optional): Padding mask.
            attn_mask (torch.Tensor, optional): Attention mask.
            query_pos (torch.Tensor, optional): Query position embeddings.

        Returns:
            (torch.Tensor): Output tensor after decoder layer.
        """
        # Self attention
        q = k = self.with_pos_embed(embed, query_pos)
        tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[
            0
        ].transpose(0, 1)
        embed = embed + self.dropout1(tgt)
        embed = self.norm1(embed)

        # Cross attention
        tgt = self.cross_attn(
            self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask
        )
        embed = embed + self.dropout2(tgt)
        embed = self.norm2(embed)

        # FFN
        return self.forward_ffn(embed)

class DeformableTransformerDecoder(nn.Module):
    """
    Implementation of Deformable Transformer Decoder based on PaddleDetection.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py

    Attributes:
        layers (nn.ModuleList): List of decoder layers.
        num_layers (int): Number of decoder layers.
        hidden_dim (int): Hidden dimension.
        eval_idx (int): Index of the layer to use during evaluation.
    """

    def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
        """
        Initialize the DeformableTransformerDecoder with the given parameters.

        Args:
            hidden_dim (int): Hidden dimension.
            decoder_layer (nn.Module): Decoder layer module.
            num_layers (int): Number of decoder layers.
            eval_idx (int): Index of the layer to use during evaluation.
        """
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx

    def forward(
        self,
        embed,  # decoder embeddings
        refer_bbox,  # anchor
        feats,  # image features
        shapes,  # feature shapes
        bbox_head,
        score_head,
        pos_mlp,
        attn_mask=None,
        padding_mask=None,
    ):
        """
        Perform the forward pass through the entire decoder.

        Args:
            embed (torch.Tensor): Decoder embeddings.
            refer_bbox (torch.Tensor): Reference bounding boxes.
            feats (torch.Tensor): Image features.
            shapes (list): Feature shapes.
            bbox_head (nn.Module): Bounding box prediction head.
            score_head (nn.Module): Score prediction head.
            pos_mlp (nn.Module): Position MLP.
            attn_mask (torch.Tensor, optional): Attention mask.
            padding_mask (torch.Tensor, optional): Padding mask.

        Returns:
            dec_bboxes (torch.Tensor): Decoded bounding boxes.
            dec_cls (torch.Tensor): Decoded classification scores.
        """
        output = embed
        dec_bboxes = []
        dec_cls = []
        last_refined_bbox = None
        refer_bbox = refer_bbox.sigmoid()
        for i, layer in enumerate(self.layers):
            output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))

            bbox = bbox_head[i](output)
            refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox))

            if self.training:
                dec_cls.append(score_head[i](output))
                if i == 0:
                    dec_bboxes.append(refined_bbox)
                else:
                    dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox)))
            elif i == self.eval_idx:
                dec_cls.append(score_head[i](output))
                dec_bboxes.append(refined_bbox)
                break

            last_refined_bbox = refined_bbox
            refer_bbox = refined_bbox.detach() if self.training else refined_bbox

        return torch.stack(dec_bboxes), torch.stack(dec_cls)
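To close the file, a hedged sketch of wiring the decoder the way an RT-DETR-style detection head drives it, with per-layer box and score heads and a position MLP; every dimension, head, and tensor below is illustrative rather than taken from the package.

# Decoder wiring sketch: two cloned layers, per-layer heads; eval keeps only the eval_idx layer.
import torch
import torch.nn as nn

num_layers, hd, nc = 2, 256, 80
decoder = DeformableTransformerDecoder(
    hidden_dim=hd,
    decoder_layer=DeformableTransformerDecoderLayer(d_model=hd, n_heads=8, n_levels=2),
    num_layers=num_layers,
)
bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(num_layers)])
score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(num_layers)])
pos_mlp = MLP(4, 2 * hd, hd, num_layers=2)

bs, nq = 2, 100
shapes = [(32, 32), (16, 16)]
feats = torch.randn(bs, sum(h * w for h, w in shapes), hd)  # flattened, concatenated levels
embed = torch.randn(bs, nq, hd)                             # query embeddings
refer_bbox = torch.randn(bs, nq, 4)                         # unnormalized anchors; sigmoid applied inside
decoder.eval()
dec_bboxes, dec_cls = decoder(embed, refer_bbox, feats, shapes, bbox_head, score_head, pos_mlp)
# dec_bboxes: [1, 2, 100, 4], dec_cls: [1, 2, 100, 80] (only the eval_idx layer is kept in eval mode)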