dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,713 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+ """Transformer modules."""
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn.init import constant_, xavier_uniform_
10
+
11
+ from .conv import Conv
12
+ from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch
13
+
14
# Public names exported by a star-import of this module (order kept stable).
__all__ = (
    "TransformerEncoderLayer", "TransformerLayer", "TransformerBlock", "MLPBlock", "LayerNorm2d",
    "AIFI", "DeformableTransformerDecoder", "DeformableTransformerDecoderLayer", "MSDeformAttn", "MLP",
)
26
+
27
+
28
class TransformerEncoderLayer(nn.Module):
    """
    Single transformer encoder layer: multi-head self-attention followed by a two-layer feedforward network.

    Sequences are batch-first ([B, L, C]). The layer runs in post-normalization mode by default, or in
    pre-normalization mode when ``normalize_before`` is True.

    Attributes:
        ma (nn.MultiheadAttention): Batch-first multi-head self-attention.
        fc1 (nn.Linear): Feedforward expansion layer (c1 -> cm).
        fc2 (nn.Linear): Feedforward projection layer (cm -> c1).
        norm1 (nn.LayerNorm): Normalization around the attention sub-block.
        norm2 (nn.LayerNorm): Normalization around the feedforward sub-block.
        dropout (nn.Dropout): Dropout between the activated fc1 output and fc2.
        dropout1 (nn.Dropout): Dropout on the attention output before its residual add.
        dropout2 (nn.Dropout): Dropout on the feedforward output before its residual add.
        act (nn.Module): Activation applied to the fc1 output.
        normalize_before (bool): True selects pre-norm, False selects post-norm.
    """

    def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False):
        """
        Build the attention, feedforward, normalization and dropout sub-modules.

        Args:
            c1 (int): Token (embedding) dimension.
            cm (int): Hidden width of the feedforward network.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability used by attention and all dropout layers.
            act (nn.Module): Activation module applied after fc1.
            normalize_before (bool): Use pre-normalization instead of post-normalization.

        Raises:
            ModuleNotFoundError: If the installed torch predates 1.9 (no ``batch_first`` attention).
        """
        super().__init__()
        from ...utils.torch_utils import TORCH_1_9

        if not TORCH_1_9:
            raise ModuleNotFoundError(
                "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)."
            )
        # Creation order is kept fixed: it determines random initialization, and the attribute names
        # double as state_dict keys.
        self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
        self.fc1, self.fc2 = nn.Linear(c1, cm), nn.Linear(cm, c1)
        self.norm1, self.norm2 = nn.LayerNorm(c1), nn.LayerNorm(c1)
        self.dropout, self.dropout1, self.dropout2 = (nn.Dropout(dropout) for _ in range(3))
        self.act = act
        self.normalize_before = normalize_before

    @staticmethod
    def with_pos_embed(tensor, pos=None):
        """Return ``tensor + pos``, or ``tensor`` unchanged when no positional embedding is given."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """
        Post-norm pass: residual add first, then LayerNorm, for both sub-blocks.

        Args:
            src (torch.Tensor): Input sequence.
            src_mask (torch.Tensor, optional): Attention mask over the sequence.
            src_key_padding_mask (torch.Tensor, optional): Per-batch key padding mask.
            pos (torch.Tensor, optional): Positional embedding added to queries and keys only.

        Returns:
            (torch.Tensor): Encoded sequence, same shape as ``src``.
        """
        # Queries and keys are the same tensor object; values stay free of positional information.
        qk = self.with_pos_embed(src, pos)
        attn_out = self.ma(qk, qk, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = self.norm1(src + self.dropout1(attn_out))
        ffn_out = self.fc2(self.dropout(self.act(self.fc1(src))))
        return self.norm2(src + self.dropout2(ffn_out))

    def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """
        Pre-norm pass: LayerNorm is applied to each sub-block's input before the residual branch.

        Args:
            src (torch.Tensor): Input sequence.
            src_mask (torch.Tensor, optional): Attention mask over the sequence.
            src_key_padding_mask (torch.Tensor, optional): Per-batch key padding mask.
            pos (torch.Tensor, optional): Positional embedding added to queries and keys only.

        Returns:
            (torch.Tensor): Encoded sequence, same shape as ``src``.
        """
        normed = self.norm1(src)
        qk = self.with_pos_embed(normed, pos)
        attn_out = self.ma(qk, qk, value=normed, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(attn_out)
        ffn_out = self.fc2(self.dropout(self.act(self.fc1(self.norm2(src)))))
        return src + self.dropout2(ffn_out)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """
        Encode ``src`` with the pre- or post-norm variant selected by ``normalize_before``.

        Args:
            src (torch.Tensor): Input sequence.
            src_mask (torch.Tensor, optional): Attention mask over the sequence.
            src_key_padding_mask (torch.Tensor, optional): Per-batch key padding mask.
            pos (torch.Tensor, optional): Positional embedding.

        Returns:
            (torch.Tensor): Output of the encoder layer.
        """
        runner = self.forward_pre if self.normalize_before else self.forward_post
        return runner(src, src_mask, src_key_padding_mask, pos)
141
+
142
+
143
class AIFI(TransformerEncoderLayer):
    """
    AIFI transformer layer: a TransformerEncoderLayer applied to 2D feature maps.

    The [B, C, H, W] input is flattened into a sequence of H*W tokens. A fixed 2D sine-cosine positional
    embedding is passed as ``pos``, so it is added to the attention queries and keys. The encoded sequence is
    then reshaped back to [B, C, H, W].
    """

    def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False):
        """
        Initialize the AIFI instance with specified parameters.

        Args:
            c1 (int): Input dimension (channels of the feature map); must be divisible by 4 for the
                positional embedding built in ``forward``.
            cm (int): Hidden dimension in the feedforward network.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability.
            act (nn.Module): Activation function.
            normalize_before (bool): Whether to apply normalization before attention and feedforward.
        """
        super().__init__(c1, cm, num_heads, dropout, act, normalize_before)

    def forward(self, x):
        """
        Forward pass for the AIFI transformer layer.

        Args:
            x (torch.Tensor): Input tensor with shape [B, C, H, W].

        Returns:
            (torch.Tensor): Output tensor with shape [B, C, H, W].
        """
        c, h, w = x.shape[1:]
        # [1, H*W, C] embedding, broadcast over the batch when added inside the encoder layer.
        pos_embed = self.build_2d_sincos_position_embedding(w, h, c)
        # Flatten [B, C, H, W] to [B, HxW, C]
        x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))
        return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()

    @staticmethod
    def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
        """
        Build 2D sine-cosine position embedding.

        Args:
            w (int): Width of the feature map.
            h (int): Height of the feature map.
            embed_dim (int): Embedding dimension; must be divisible by 4.
            temperature (float): Temperature for the sine/cosine functions.

        Returns:
            (torch.Tensor): Position embedding with shape [1, h*w, embed_dim]. The channel layout is
                [sin(w), cos(w), sin(h), cos(h)], each of width embed_dim // 4.

        Raises:
            AssertionError: If ``embed_dim`` is not divisible by 4.
        """
        assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        # NOTE(review): "ij" indexing on (grid_w, grid_h) gives [w, h] grids, so flatten() is w-major
        # (token k -> w-index k // h). x.flatten(2) in forward is h-major (token k -> row k // w), so the
        # "w"/"h" halves do not line up with actual columns/rows. Each token still gets a unique embedding.
        # Kept as-is because changing it would alter the outputs of already-trained weights.
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
        pos_dim = embed_dim // 4
        # Geometric frequency ladder: temperature ** (-i / pos_dim) for i in [0, pos_dim).
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1.0 / (temperature**omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]  # [h*w, pos_dim]
        out_h = grid_h.flatten()[..., None] @ omega[None]  # [h*w, pos_dim]

        return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]
206
+
207
+
208
class TransformerLayer(nn.Module):
    """
    Lightweight transformer layer without LayerNorm (https://arxiv.org/abs/2010.11929).

    Queries, keys and values come from separate bias-free linear projections of the input. The attention output
    and a two-layer bias-free linear stack are each added back to their input as residuals.
    """

    def __init__(self, c, num_heads):
        """
        Create the projection, attention and feedforward sub-modules.

        Args:
            c (int): Channel dimension, used for both input and output.
            num_heads (int): Number of attention heads.
        """
        super().__init__()
        # Creation order (q, k, v, ma, fc1, fc2) is fixed so initialization and state_dict layout are stable.
        self.q, self.k, self.v = (nn.Linear(c, c, bias=False) for _ in range(3))
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1, self.fc2 = (nn.Linear(c, c, bias=False) for _ in range(2))

    def forward(self, x):
        """
        Apply residual self-attention, then the residual linear stack.

        Args:
            x (torch.Tensor): Input sequence. ``ma`` is built without ``batch_first``, so this is presumably
                sequence-first [L, B, c].

        Returns:
            (torch.Tensor): Tensor with the same shape as ``x``.
        """
        attended, _ = self.ma(self.q(x), self.k(x), self.v(x))
        x = x + attended
        return x + self.fc2(self.fc1(x))
239
+
240
+
241
class TransformerBlock(nn.Module):
    """
    Vision Transformer block https://arxiv.org/abs/2010.11929.

    An optional channel-adapting Conv is applied first. The feature map is then flattened into a
    sequence-first token sequence and run through a stack of TransformerLayer modules.

    Attributes:
        conv (Conv, optional): Channel-adapting convolution, present only when c1 != c2.
        linear (nn.Linear): Projection whose output is added to the tokens (labelled a learnable position
            embedding).
        tr (nn.Sequential): Stack of TransformerLayer modules.
        c2 (int): Output channel dimension.
    """

    def __init__(self, c1, c2, num_heads, num_layers):
        """
        Build the optional Conv, the token projection, and the transformer layer stack.

        Args:
            c1 (int): Input channel dimension.
            c2 (int): Output channel dimension.
            num_heads (int): Number of attention heads per layer.
            num_layers (int): Number of stacked TransformerLayer modules.
        """
        super().__init__()
        self.conv = Conv(c1, c2) if c1 != c2 else None
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
        self.c2 = c2

    def forward(self, x):
        """
        Run the feature map through the transformer stack.

        Args:
            x (torch.Tensor): Input tensor with shape [b, c1, h, w].

        Returns:
            (torch.Tensor): Output tensor with shape [b, c2, h, w].
        """
        if self.conv is not None:
            x = self.conv(x)
        b, _, h, w = x.shape
        tokens = x.flatten(2).permute(2, 0, 1)  # [h*w, b, c2], sequence-first
        encoded = self.tr(tokens + self.linear(tokens))
        return encoded.permute(1, 2, 0).reshape(b, self.c2, h, w)
285
+
286
+
287
class MLPBlock(nn.Module):
    """Two-layer perceptron block: Linear -> activation -> Linear, mapping embedding_dim back to itself."""

    def __init__(self, embedding_dim, mlp_dim, act=nn.GELU):
        """
        Create the two linear layers and instantiate the activation.

        Args:
            embedding_dim (int): Input and output feature size.
            mlp_dim (int): Hidden feature size.
            act (type[nn.Module]): Activation class; it is instantiated with no arguments.
        """
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project up, activate, and project back down.

        Args:
            x (torch.Tensor): Input whose last dimension is ``embedding_dim``.

        Returns:
            (torch.Tensor): Tensor with the same shape as ``x``.
        """
        hidden = self.act(self.lin1(x))
        return self.lin2(hidden)
315
+
316
+
317
class MLP(nn.Module):
    """
    Implements a simple multi-layer perceptron (also called FFN).

    Attributes:
        num_layers (int): Number of layers in the MLP.
        layers (nn.ModuleList): List of linear layers.
        sigmoid (bool): Whether to apply sigmoid to the output.
        act (nn.Module): Activation applied between layers (not after the last one).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act=nn.ReLU, sigmoid=False):
        """
        Initialize the MLP with specified input, hidden, output dimensions and number of layers.

        Args:
            input_dim (int): Input dimension.
            hidden_dim (int): Hidden dimension.
            output_dim (int): Output dimension.
            num_layers (int): Number of layers.
            act (type[nn.Module]): Activation class; instantiated with no arguments.
            sigmoid (bool): Whether to apply sigmoid to the output.
        """
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.sigmoid = sigmoid
        self.act = act()

    def forward(self, x):
        """
        Forward pass for the entire MLP.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after MLP, passed through sigmoid when ``self.sigmoid`` is set.
        """
        # Instances may lack `act`/`sigmoid` (presumably older checkpoints); fall back to ReLU / no sigmoid.
        # Resolve the activation once. The previous getattr(self, "act", nn.ReLU()) evaluated its default
        # eagerly, building a throwaway nn.ReLU module on every layer iteration.
        act = getattr(self, "act", None)
        if act is None:
            act = F.relu
        last = self.num_layers - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:  # no activation after the final layer
                x = act(x)
        return x.sigmoid() if getattr(self, "sigmoid", False) else x
360
+
361
+
362
class LayerNorm2d(nn.Module):
    """
    Channel-wise Layer Normalization for [B, C, H, W] feature maps, following Detectron2 and ConvNeXt.

    References:
        https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py
        https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py

    Attributes:
        weight (nn.Parameter): Per-channel scale, initialized to ones.
        bias (nn.Parameter): Per-channel shift, initialized to zeros.
        eps (float): Value added to the variance before the square root.
    """

    def __init__(self, num_channels, eps=1e-6):
        """
        Allocate per-channel affine parameters.

        Args:
            num_channels (int): Size of the channel dimension (dim 1) of the input.
            eps (float): Numerical-stability term added to the variance.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        """
        Normalize each spatial position across channels, then apply the per-channel affine transform.

        Args:
            x (torch.Tensor): Feature map of shape [B, C, H, W].

        Returns:
            (torch.Tensor): Normalized tensor with the same shape as ``x``.
        """
        centered = x - x.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)  # biased (population) variance over channels
        normalized = centered / torch.sqrt(variance + self.eps)
        scale = self.weight.view(-1, 1, 1)
        shift = self.bias.view(-1, 1, 1)
        return scale * normalized + shift
404
+
405
+
406
+ class MSDeformAttn(nn.Module):
407
+ """
408
+ Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.
409
+
410
+ https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
411
+
412
+ Attributes:
413
+ im2col_step (int): Step size for im2col operations.
414
+ d_model (int): Model dimension.
415
+ n_levels (int): Number of feature levels.
416
+ n_heads (int): Number of attention heads.
417
+ n_points (int): Number of sampling points per attention head per feature level.
418
+ sampling_offsets (nn.Linear): Linear layer for generating sampling offsets.
419
+ attention_weights (nn.Linear): Linear layer for generating attention weights.
420
+ value_proj (nn.Linear): Linear layer for projecting values.
421
+ output_proj (nn.Linear): Linear layer for projecting output.
422
+ """
423
+
424
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
425
+ """
426
+ Initialize MSDeformAttn with the given parameters.
427
+
428
+ Args:
429
+ d_model (int): Model dimension.
430
+ n_levels (int): Number of feature levels.
431
+ n_heads (int): Number of attention heads.
432
+ n_points (int): Number of sampling points per attention head per feature level.
433
+ """
434
+ super().__init__()
435
+ if d_model % n_heads != 0:
436
+ raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
437
+ _d_per_head = d_model // n_heads
438
+ # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation
439
+ assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`"
440
+
441
+ self.im2col_step = 64
442
+
443
+ self.d_model = d_model
444
+ self.n_levels = n_levels
445
+ self.n_heads = n_heads
446
+ self.n_points = n_points
447
+
448
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
449
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
450
+ self.value_proj = nn.Linear(d_model, d_model)
451
+ self.output_proj = nn.Linear(d_model, d_model)
452
+
453
+ self._reset_parameters()
454
+
455
+ def _reset_parameters(self):
456
+ """Reset module parameters."""
457
+ constant_(self.sampling_offsets.weight.data, 0.0)
458
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
459
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
460
+ grid_init = (
461
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
462
+ .view(self.n_heads, 1, 1, 2)
463
+ .repeat(1, self.n_levels, self.n_points, 1)
464
+ )
465
+ for i in range(self.n_points):
466
+ grid_init[:, :, i, :] *= i + 1
467
+ with torch.no_grad():
468
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
469
+ constant_(self.attention_weights.weight.data, 0.0)
470
+ constant_(self.attention_weights.bias.data, 0.0)
471
+ xavier_uniform_(self.value_proj.weight.data)
472
+ constant_(self.value_proj.bias.data, 0.0)
473
+ xavier_uniform_(self.output_proj.weight.data)
474
+ constant_(self.output_proj.bias.data, 0.0)
475
+
476
def forward(self, query, refer_bbox, value, value_shapes, value_mask=None):
    """
    Perform forward pass for multiscale deformable attention.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py

    Args:
        query (torch.Tensor): Tensor with shape [bs, query_length, C].
        refer_bbox (torch.Tensor): Reference points with shape [bs, query_length, n_levels, 2] or
            [bs, query_length, n_levels, 4], range in [0, 1], top-left (0, 0), bottom-right (1, 1),
            including padding area.
            - With 2 coordinates, they are the sampling center.
            - With 4 coordinates, the first two are the center and the last two are a box size that scales the
              offsets.
            A level dimension of size 1 is broadcast across all levels.
        value (torch.Tensor): Tensor with shape [bs, value_length, C].
        value_shapes (list): List with shape [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})].
            The sum of H_l * W_l must equal value_length.
        value_mask (torch.Tensor, optional): Boolean tensor with shape [bs, value_length]. True marks padding
            positions, whose projected values are zeroed before sampling. False marks valid elements.

    Returns:
        (torch.Tensor): Output tensor with shape [bs, query_length, C].

    Raises:
        ValueError: If the last dimension of `refer_bbox` is neither 2 nor 4.
    """
    bs, len_q = query.shape[:2]
    len_v = value.shape[1]
    assert sum(s[0] * s[1] for s in value_shapes) == len_v, "sum of H*W over value_shapes must equal value length"

    value = self.value_proj(value)
    if value_mask is not None:
        # Zero padded positions so they contribute nothing when sampled
        value = value.masked_fill(value_mask[..., None], float(0))
    value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)
    sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, self.n_levels, self.n_points, 2)
    attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)
    # Softmax jointly over all (level, point) pairs of each head, then split the axis back out
    attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)
    # sampling_locations: [bs, len_q, n_heads, n_levels, n_points, 2]
    coord_dim = refer_bbox.shape[-1]
    if coord_dim == 2:
        # value_shapes holds (H, W); flip to (W, H) so offsets are divided per-axis by each level's size
        offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
        add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        sampling_locations = refer_bbox[:, :, None, :, None, :] + add
    elif coord_dim == 4:
        # Offsets are scaled by half of the box size and divided by the number of sampling points
        add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
        sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
    else:
        raise ValueError(f"Last dim of refer_bbox must be 2 or 4, but got {coord_dim}.")
    output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
    return self.output_proj(output)
518
+
519
+
520
class DeformableTransformerDecoderLayer(nn.Module):
    """
    One post-norm decoder layer for a deformable transformer, modeled on PaddleDetection and Deformable-DETR.

    The layer runs three residual sub-blocks, each followed by LayerNorm:
    1. self-attention among the queries,
    2. multi-scale deformable cross-attention into the image features,
    3. a two-layer feed-forward network.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
    https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py

    Attributes:
        self_attn (nn.MultiheadAttention): Query self-attention, fed sequence-first (transposed) tensors.
        dropout1 (nn.Dropout): Dropout on the self-attention output before its residual add.
        norm1 (nn.LayerNorm): Normalization after the self-attention residual.
        cross_attn (MSDeformAttn): Multi-scale deformable cross-attention over the image features.
        dropout2 (nn.Dropout): Dropout on the cross-attention output before its residual add.
        norm2 (nn.LayerNorm): Normalization after the cross-attention residual.
        linear1 (nn.Linear): FFN projection from d_model to d_ffn.
        act (nn.Module): FFN activation.
        dropout3 (nn.Dropout): Dropout between the two FFN projections.
        linear2 (nn.Linear): FFN projection from d_ffn back to d_model.
        dropout4 (nn.Dropout): Dropout on the FFN output before its residual add.
        norm3 (nn.LayerNorm): Normalization after the FFN residual.
    """

    def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0.0, act=nn.ReLU(), n_levels=4, n_points=4):
        """
        Build the self-attention, deformable cross-attention and feed-forward sub-blocks.

        Args:
            d_model (int): Width of the query embeddings.
            n_heads (int): Number of heads for both attention modules.
            d_ffn (int): Hidden width of the feed-forward network.
            dropout (float): Dropout probability used throughout the layer.
            act (nn.Module): Activation for the feed-forward network. NOTE: the default instance is created once,
                when the class is defined, and is shared by every layer built without an explicit `act`.
            n_levels (int): Number of feature levels sampled by the cross-attention.
            n_points (int): Number of sampling points per head and level in the cross-attention.
        """
        super().__init__()

        # Sub-block 1: self-attention among the queries
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Sub-block 2: deformable cross-attention into the multi-scale features
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # Sub-block 3: position-wise feed-forward network
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.act = act
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return `tensor + pos`, or `tensor` unchanged when no positional embedding is given."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_ffn(self, tgt):
        """
        Apply the feed-forward sub-block with its residual connection and LayerNorm.

        Args:
            tgt (torch.Tensor): Input tensor whose last dimension is d_model.

        Returns:
            (torch.Tensor): Normalized tensor with the same shape as `tgt`.
        """
        hidden = self.dropout3(self.act(self.linear1(tgt)))
        return self.norm3(tgt + self.dropout4(self.linear2(hidden)))

    def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
        """
        Run self-attention, deformable cross-attention and the FFN on the query embeddings.

        Args:
            embed (torch.Tensor): Query embeddings, batch-first.
            refer_bbox (torch.Tensor): Reference points/boxes whose last dimension is 2 or 4. A level axis is
                inserted at dim 2 before cross-attention.
            feats (torch.Tensor): Flattened multi-scale image features used as attention values.
            shapes (list): Per-level (H, W) sizes of the features.
            padding_mask (torch.Tensor, optional): Feature padding mask forwarded to the cross-attention.
            attn_mask (torch.Tensor, optional): Mask for the query self-attention.
            query_pos (torch.Tensor, optional): Positional embedding added to the queries (and self-attention keys).

        Returns:
            (torch.Tensor): Updated query embeddings with the same shape as `embed`.
        """
        # Self-attention: nn.MultiheadAttention is sequence-first here, so swap batch and query axes around the call
        qk = self.with_pos_embed(embed, query_pos)
        attn_out, _ = self.self_attn(
            qk.transpose(0, 1), qk.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask
        )
        embed = self.norm1(embed + self.dropout1(attn_out.transpose(0, 1)))

        # Cross-attention: a single level axis on refer_bbox broadcasts across every feature level
        cross_out = self.cross_attn(
            self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask
        )
        embed = self.norm2(embed + self.dropout2(cross_out))

        return self.forward_ffn(embed)
627
+
628
+
629
class DeformableTransformerDecoder(nn.Module):
    """
    Implementation of Deformable Transformer Decoder based on PaddleDetection.

    The decoder stacks `num_layers` copies of a decoder layer and iteratively refines reference boxes. After each
    layer, that layer's box-head output is added to the current reference boxes in inverse-sigmoid (logit) space.
    - Training: predictions from every layer are returned.
    - Evaluation: layers run only up to `eval_idx`, and only that layer's prediction is returned.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py

    Attributes:
        layers (nn.ModuleList): List of decoder layers.
        num_layers (int): Number of decoder layers.
        hidden_dim (int): Hidden dimension.
        eval_idx (int): Non-negative index of the layer whose output is returned during evaluation.
    """

    def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
        """
        Initialize the DeformableTransformerDecoder with the given parameters.

        Args:
            hidden_dim (int): Hidden dimension.
            decoder_layer (nn.Module): Decoder layer module; it is cloned `num_layers` times.
            num_layers (int): Number of decoder layers.
            eval_idx (int): Index of the layer to use during evaluation. Negative values count from the end
                (-1 is the last layer). Must satisfy -num_layers <= eval_idx < num_layers.

        Raises:
            ValueError: If `eval_idx` is out of range for `num_layers`.
        """
        super().__init__()
        if not -num_layers <= eval_idx < num_layers:
            # An out-of-range index never matches in forward(), so eval mode would end up stacking empty lists
            raise ValueError(f"eval_idx must be in [{-num_layers}, {num_layers - 1}], but got {eval_idx}")
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx

    def forward(
        self,
        embed,  # decoder embeddings
        refer_bbox,  # anchor
        feats,  # image features
        shapes,  # feature shapes
        bbox_head,
        score_head,
        pos_mlp,
        attn_mask=None,
        padding_mask=None,
    ):
        """
        Perform the forward pass through the entire decoder.

        Args:
            embed (torch.Tensor): Decoder embeddings.
            refer_bbox (torch.Tensor): Reference bounding boxes in logit space. A sigmoid is applied before the
                first layer.
            feats (torch.Tensor): Image features.
            shapes (list): Feature shapes.
            bbox_head (nn.ModuleList): Indexable per-layer box-prediction heads; `bbox_head[i]` serves layer i.
            score_head (nn.ModuleList): Indexable per-layer score heads; `score_head[i]` serves layer i.
            pos_mlp (nn.Module): Maps the current reference boxes to query position embeddings.
            attn_mask (torch.Tensor, optional): Attention mask.
            padding_mask (torch.Tensor, optional): Padding mask.

        Returns:
            dec_bboxes (torch.Tensor): Refined boxes stacked along a new leading axis. It has one entry per layer
                in training, or a single entry (layer `eval_idx`) in evaluation.
            dec_cls (torch.Tensor): Outputs of `score_head`, stacked the same way.
        """
        output = embed
        dec_bboxes = []
        dec_cls = []
        last_refined_bbox = None
        refer_bbox = refer_bbox.sigmoid()
        for i, layer in enumerate(self.layers):
            output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))

            bbox = bbox_head[i](output)
            refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox))

            if self.training:
                dec_cls.append(score_head[i](output))
                if i == 0:
                    dec_bboxes.append(refined_bbox)
                else:
                    # Apply this layer's delta to the previous layer's *undetached* box. The next layer's reference
                    # is detached, so this is what lets the loss on layer i also reach layer i-1's box prediction.
                    dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox)))
            elif i == self.eval_idx:
                dec_cls.append(score_head[i](output))
                dec_bboxes.append(refined_bbox)
                break  # later layers are not needed for inference

            last_refined_bbox = refined_bbox
            # Stop gradients through the reference chain during training
            refer_bbox = refined_bbox.detach() if self.training else refined_bbox

        return torch.stack(dec_bboxes), torch.stack(dec_cls)